How to use the asteroid.torch_utils.pad_x_to_y function in asteroid

To help you get started, we’ve selected a few asteroid.torch_utils.pad_x_to_y examples, based on popular ways the function is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github mpariente / AsSteroid / egs / dns_challenge / baseline / model.py View on Github external
def denoise(self, x):
        """Run the model on ``x`` and return the decoded result.

        The model itself is called on ``x``, its output is handed to
        ``self.decoder``, and the decoded signal is passed through
        ``torch_utils.pad_x_to_y`` with ``x`` as the reference.

        NOTE(review): presumably ``pad_x_to_y`` makes the output length match
        ``x`` -- confirm against its definition.
        """
        decoded = self.decoder(self(x))
        return torch_utils.pad_x_to_y(decoded, x)
github mpariente / AsSteroid / egs / whamr / TasNet / model.py View on Github external
def forward(self, x):
        """Encode ``x``, estimate one mask per source, and decode the result.

        A 2-D input gets a singleton axis inserted at position 1 before
        encoding. The masks are reshaped to
        ``(batch, self.n_src, self.n_filters, -1)`` and multiplied with the
        encoded input (which gets a new axis at position 1). The decoder
        output is passed through ``torch_utils.pad_x_to_y`` with ``x`` as
        the reference.
        """
        if x.ndim == 2:
            x = x.unsqueeze(1)
        # The leading dimension is not affected by the unsqueeze above.
        n_batch = x.shape[0]

        encoded = self.encode(x)
        bottleneck = self.bn_layer(encoded)

        # The masker is fed with the last two axes swapped, and its output
        # is swapped back before being reshaped.
        raw_masks = self.masker(bottleneck.transpose(-1, -2))
        masks = raw_masks.transpose(-1, -2).view(
            n_batch, self.n_src, self.n_filters, -1
        )

        decoded = self.decoder(encoded.unsqueeze(1) * masks)
        return torch_utils.pad_x_to_y(decoded, x)
github mpariente / AsSteroid / asteroid / models / base_models.py View on Github external
Returns:
            torch.Tensor, of shape (batch, n_src, time) or (n_src, time).
        """
        # Handle 1D, 2D or n-D inputs
        was_one_d = False
        if wav.ndim == 1:
            was_one_d = True
            wav = wav.unsqueeze(0).unsqueeze(1)
        if wav.ndim == 2:
            wav = wav.unsqueeze(1)
        # Real forward
        tf_rep = self.encoder(wav)
        est_masks = self.masker(tf_rep)
        masked_tf_rep = est_masks * tf_rep.unsqueeze(1)
        out_wavs = torch_utils.pad_x_to_y(self.decoder(masked_tf_rep), wav)
        if was_one_d:
            return out_wavs.squeeze(0)
        return out_wavs
github mpariente / AsSteroid / egs / wsj0-mix / DeepClustering / model.py View on Github external
def separate(self, x):
        """Separate with the mask-inference head and output waveforms.

        Returns:
            A tuple ``(wavs, dic_out)``. ``wavs`` is the decoder output
            passed through ``torch_utils.pad_x_to_y`` with ``x`` as the
            reference. ``dic_out`` holds the intermediate values under the
            keys ``tfrep``, ``mask``, ``masked_tfrep`` and ``proj``.
        """
        # A 2-D input gets a singleton axis inserted at position 1.
        if x.ndim == 2:
            x = x.unsqueeze(1)

        encoded = self.encoder(x)
        # The masker returns two values: a projection and the mask output.
        projection, mask = self.masker(take_mag(encoded))
        masked_rep = apply_mag_mask(encoded.unsqueeze(1), mask)
        wavs = torch_utils.pad_x_to_y(self.decoder(masked_rep), x)

        intermediates = {
            "tfrep": encoded,
            "mask": mask,
            "masked_tfrep": masked_rep,
            "proj": projection,
        }
        return wavs, intermediates
github mpariente / AsSteroid / egs / fuss / baseline / model.py View on Github external
def forward(self, x, bg=None):
        """Estimate source waveforms from ``x``, optionally using ``bg``.

        Args:
            x: Input signal. A 2-D input gets a singleton axis inserted at
                position 1.
            bg: Optional signal subtracted from ``x`` before the mixture
                consistency step. When ``None``, that step is skipped.
                A 2-D ``bg`` also gets a singleton axis at position 1.

        Returns:
            The estimated waveforms, passed through ``pad_x_to_y`` with
            ``x`` as the reference.
        """
        if x.ndim == 2:
            x = x.unsqueeze(1)
        tf_rep = self.encoder(x)

        # Concat ReIm and Mag input
        masker_in = transforms.take_cat(tf_rep)
        if self.learnable_scaling:
            # With learnable scaling the masker also returns weights.
            est_masks, weights = self.masker(masker_in)
        else:
            est_masks = self.masker(masker_in)

        # Note : this is equivalent to ReIm masking for STFT
        out_wavs = self.decoder(est_masks * tf_rep.unsqueeze(1))

        if bg is not None:
            if bg.ndim == 2:
                bg = bg.unsqueeze(1)
            # Mixture consistency: estimates should sum up to the targets
            # only, hence ``x - bg``. The masker's weights are forwarded as
            # ``src_weights`` only when learnable scaling is enabled.
            extra = {"src_weights": weights} if self.learnable_scaling else {}
            out_wavs = consistency.mixture_consistency(x - bg, out_wavs, **extra)

        return pad_x_to_y(out_wavs, x)
github mpariente / AsSteroid / egs / fuss / baseline / model.py View on Github external
else:
            est_masks = self.masker(transforms.take_cat(tf_rep))
        # Note : this is equivalent to ReIm masking for STFT
        masked_tf_rep = est_masks * tf_rep.unsqueeze(1)
        out_wavs = self.decoder(masked_tf_rep)
        # Mixture consistency (weights are not learned but based on power)
        # Estimates should sum up to the targets only
        if bg is None:
            return pad_x_to_y(out_wavs, x)
        if len(bg.shape) == 2:
            bg = bg.unsqueeze(1)
        if self.learnable_scaling:
            out_wavs = consistency.mixture_consistency(x - bg, out_wavs, src_weights=weights)
        else:
            out_wavs = consistency.mixture_consistency(x - bg, out_wavs)
        return pad_x_to_y(out_wavs, x)