Secure your code as it's written: use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.
def denoise(self, x):
    """Run the model on ``x`` and decode its output back into a waveform.

    ``self(x)`` is invoked first. Its result (named as an STFT estimate in the
    original code, presumably a time-frequency representation; confirm against
    the model's ``forward``) is then fed through ``self.decoder``.

    Args:
        x: Input tensor given to ``self(x)`` and used as the length reference
            for the final alignment.

    Returns:
        The decoder output aligned to ``x`` with ``torch_utils.pad_x_to_y``.
    """
    decoded = self.decoder(self(x))
    # NOTE(review): pad_x_to_y presumably pads/trims the last axis to x's
    # length; confirm in torch_utils.
    return torch_utils.pad_x_to_y(decoded, x)
def forward(self, x):
    """Separate ``x`` with an encode -> bottleneck -> mask -> decode pipeline.

    Args:
        x: Mixture tensor. A 2D input gets a singleton axis inserted at dim 1
            before encoding (presumably ``(batch, time)`` -> ``(batch, 1, time)``;
            confirm with callers).

    Returns:
        The decoder output for the masked representation, aligned to the
        (possibly unsqueezed) ``x`` with ``torch_utils.pad_x_to_y``.
    """
    n_batch = x.shape[0]
    if x.ndim == 2:
        x = x.unsqueeze(1)

    encoded = self.encode(x)
    bottleneck = self.bn_layer(encoded)

    # The masker is given the bottleneck with its last two axes swapped, and
    # its output is swapped back before reshaping.
    masks = self.masker(bottleneck.transpose(-1, -2)).transpose(-1, -2)
    masks = masks.view(n_batch, self.n_src, self.n_filters, -1)

    # A new axis at dim 1 lets the single encoded representation broadcast
    # against one mask per source.
    separated_rep = encoded.unsqueeze(1) * masks

    waveforms = self.decoder(separated_rep)
    return torch_utils.pad_x_to_y(waveforms, x)
Returns:
torch.Tensor, of shape (batch, n_src, time) or (n_src, time).
"""
# Handle 1D, 2D or n-D inputs
was_one_d = False
if wav.ndim == 1:
was_one_d = True
wav = wav.unsqueeze(0).unsqueeze(1)
if wav.ndim == 2:
wav = wav.unsqueeze(1)
# Real forward
tf_rep = self.encoder(wav)
est_masks = self.masker(tf_rep)
masked_tf_rep = est_masks * tf_rep.unsqueeze(1)
out_wavs = torch_utils.pad_x_to_y(self.decoder(masked_tf_rep), wav)
if was_one_d:
return out_wavs.squeeze(0)
return out_wavs
def separate(self, x):
    """Separate ``x`` with the mask-inference head and return waveforms.

    Args:
        x: Mixture tensor; a 2D input gets a singleton axis inserted at dim 1.

    Returns:
        tuple: ``(wavs, dic_out)``.

        - ``wavs`` is the decoder output for the masked representation,
          aligned to ``x`` with ``torch_utils.pad_x_to_y``.
        - ``dic_out`` is a dict with keys ``"tfrep"`` (encoder output),
          ``"mask"`` (the masker's second output), ``"masked_tfrep"`` (the
          masked representation) and ``"proj"`` (the masker's first output).
    """
    if x.ndim == 2:
        x = x.unsqueeze(1)

    rep = self.encoder(x)
    # The masker only sees take_mag(rep) (presumably the magnitude of the
    # encoder output) and returns a projection plus a mask.
    projection, mask = self.masker(take_mag(rep))
    masked_rep = apply_mag_mask(rep.unsqueeze(1), mask)

    separated = torch_utils.pad_x_to_y(self.decoder(masked_rep), x)
    extras = {
        "tfrep": rep,
        "mask": mask,
        "masked_tfrep": masked_rep,
        "proj": projection,
    }
    return separated, extras
def forward(self, x, bg=None):
    """Estimate sources from ``x``, optionally enforcing mixture consistency.

    Args:
        x: Mixture tensor; a 2D input gets a singleton axis inserted at dim 1.
        bg: Optional background tensor (a 2D input is likewise unsqueezed at
            dim 1). When given, the estimates are projected with
            ``consistency.mixture_consistency`` so they are consistent with
            ``x - bg``, i.e. the estimates should sum up to the targets only.

    Returns:
        Estimated waveforms aligned to ``x`` with ``pad_x_to_y``.
    """
    if len(x.shape) == 2:
        x = x.unsqueeze(1)
    tf_rep = self.encoder(x)

    # Concat ReIm and Mag input.
    masker_input = transforms.take_cat(tf_rep)
    weights = None
    if self.learnable_scaling:
        # With learnable scaling the masker also returns per-source weights
        # for the consistency step.
        est_masks, weights = self.masker(masker_input)
    else:
        est_masks = self.masker(masker_input)

    # Note: this is equivalent to ReIm masking for STFT.
    masked_tf_rep = est_masks * tf_rep.unsqueeze(1)

    # Align the decoder output with x *before* any consistency step. The
    # decoder may return a different number of samples than x, and
    # mixture_consistency combines out_wavs with x - bg, so their time axes
    # must agree.
    out_wavs = pad_x_to_y(self.decoder(masked_tf_rep), x)
    if bg is None:
        return out_wavs

    if len(bg.shape) == 2:
        bg = bg.unsqueeze(1)
    target_mix = x - bg
    if self.learnable_scaling:
        out_wavs = consistency.mixture_consistency(
            target_mix, out_wavs, src_weights=weights
        )
    else:
        # Without src_weights the weights are not learned (per the original
        # note they are based on power; confirm in `consistency`).
        out_wavs = consistency.mixture_consistency(target_mix, out_wavs)
    return pad_x_to_y(out_wavs, x)
else:
est_masks = self.masker(transforms.take_cat(tf_rep))
# Note : this is equivalent to ReIm masking for STFT
masked_tf_rep = est_masks * tf_rep.unsqueeze(1)
out_wavs = self.decoder(masked_tf_rep)
# Mixture consistency (weights are not learned but based on power)
# Estimates should sum up to the targets only
if bg is None:
return pad_x_to_y(out_wavs, x)
if len(bg.shape) == 2:
bg = bg.unsqueeze(1)
if self.learnable_scaling:
out_wavs = consistency.mixture_consistency(x - bg, out_wavs, src_weights=weights)
else:
out_wavs = consistency.mixture_consistency(x - bg, out_wavs)
return pad_x_to_y(out_wavs, x)