How to use the delve.torch_utils.TorchCovarianceMatrix function in delve

To help you get started, we’ve selected a few delve examples, based on popular ways it is used in public projects.


GitHub: delve-team / delve / delve / torchcallback.py
                shape = activations_batch.shape
                # Merge the two spatial dimensions so each channel becomes one vector
                reshaped_batch = activations_batch.reshape(shape[0], shape[1], shape[2] * shape[3])
                activations_batch, _ = torch.max(reshaped_batch, dim=2)  # channel max
            elif self.conv_method == 'mean':
                activations_batch = torch.mean(activations_batch, dim=(2, 3))
            elif self.conv_method == 'flatten':
                activations_batch = activations_batch.view(activations_batch.size(0), -1)
            elif self.conv_method == 'channelwise':
                # Move channels first: (N, C, H, W) -> (C, N, H, W)
                reshaped_batch: torch.Tensor = activations_batch.permute([1, 0, 2, 3])
                shape = reshaped_batch.shape
                # Flatten to (C, N*H*W), then transpose so each spatial position is one sample
                reshaped_batch = reshaped_batch.flatten(1)
                reshaped_batch = reshaped_batch.permute([1, 0])
                activations_batch = reshaped_batch

        if layer.name not in self.logs[f'{training_state}-{stat}']:
            self.logs[f'{training_state}-{stat}'][layer.name] = TorchCovarianceMatrix(device=self.device)

        self.logs[f'{training_state}-{stat}'][layer.name].update(activations_batch, lstm_ae)
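The excerpt above reduces each batch of convolutional activations to a 2-D (samples, features) matrix and then feeds it to a per-layer covariance accumulator. The following is a minimal, self-contained sketch of that pattern, not delve's implementation: `reduce_conv_activations` and `RunningCovariance` are hypothetical names introduced here for illustration, and this `update` takes only the batch (the real `TorchCovarianceMatrix.update` in the excerpt is also passed `lstm_ae`, whose handling is not shown in the source).

```python
import torch


def reduce_conv_activations(batch: torch.Tensor, method: str) -> torch.Tensor:
    """Collapse an (N, C, H, W) batch into a 2-D (samples, features) matrix.

    Hypothetical helper mirroring the conv_method branches in the excerpt.
    """
    if method == "max":
        # Largest value per channel over all spatial positions -> (N, C)
        return batch.flatten(2).max(dim=2).values
    if method == "mean":
        # Average per channel over all spatial positions -> (N, C)
        return batch.mean(dim=(2, 3))
    if method == "flatten":
        # Every unit is its own feature -> (N, C*H*W)
        return batch.flatten(1)
    if method == "channelwise":
        # Every spatial position becomes a sample -> (N*H*W, C)
        return batch.permute(1, 0, 2, 3).flatten(1).permute(1, 0)
    raise ValueError(f"unknown method: {method!r}")


class RunningCovariance:
    """Hypothetical batch-wise covariance accumulator (not delve's class)."""

    def __init__(self, device: str = "cpu") -> None:
        self.device = torch.device(device)
        self.n = 0           # number of samples seen so far
        self.sum = None      # running sum of samples, shape (D,)
        self.outer = None    # running sum of outer products, shape (D, D)

    def update(self, x: torch.Tensor) -> None:
        # float64 limits cancellation error in the sum-of-products formula
        x = x.detach().to(self.device, dtype=torch.float64)
        if self.sum is None:
            d = x.shape[1]
            self.sum = torch.zeros(d, dtype=torch.float64, device=self.device)
            self.outer = torch.zeros(d, d, dtype=torch.float64, device=self.device)
        self.n += x.shape[0]
        self.sum += x.sum(dim=0)
        self.outer += x.T @ x

    def covariance(self) -> torch.Tensor:
        if self.n < 2:
            raise ValueError("need at least two samples")
        mean = self.sum / self.n
        # Unbiased estimate: (sum of x x^T - n * mean mean^T) / (n - 1)
        return (self.outer - self.n * torch.outer(mean, mean)) / (self.n - 1)
```

Usage under these assumptions: create one `RunningCovariance` per layer name, call `update(reduce_conv_activations(batch, "channelwise"))` inside the forward hook, and call `covariance()` at the end of the epoch. Accumulating sums rather than storing activations keeps memory at O(D^2) per layer regardless of how many batches are seen.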