Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
Args:
estimate (torch.Tensor): Estimate complex spectrogram.
target (torch.Tensor): Speech target complex spectrogram.
is_complex (bool): Whether to compute the distance in the complex or
the magnitude space.
Returns:
torch.Tensor the loss value, in a tensor of size 1.
"""
if is_complex:
# Take the difference in the complex plane and compute the squared norm
# of the remaining vector.
return take_mag(estimate - target).pow(2).mean()
else:
# Compute the mean difference between magnitudes.
return (take_mag(estimate) - take_mag(target)).pow(2).mean()
def common_step(self, batch, batch_nb, train=False):
    """Run the model on one batch and evaluate the loss.

    Args:
        batch: Raw batch, decoded by ``self.unpack_data`` into
            ``(inputs, targets, masks)``.
        batch_nb: Batch index (unused here; kept for the caller's interface).
        train (bool): Training-step flag (unused here).

    Returns:
        The two values unpacked from ``self.loss_func(...)``, as a tuple
        ``(loss, loss_dic)``.
    """
    mixture, target_ids, target_masks = self.unpack_data(batch)
    embeddings, estimated_masks = self(mixture)
    # Magnitude of the encoded mixture. A singleton axis is inserted at
    # position 1 before encoding (presumably a (batch, 1, time) layout for
    # the encoder -- TODO confirm).
    mix_mag = take_mag(self.model.encoder(mixture.unsqueeze(1)))
    if self.mask_mixture:
        # Weight estimated and target masks by the mixture magnitude; the
        # inserted axis 1 lets it broadcast across the per-source axis.
        mix_weight = mix_mag.unsqueeze(1)
        estimated_masks = estimated_masks * mix_weight
        target_masks = target_masks * mix_weight
    loss_value, loss_details = self.loss_func(
        embeddings,
        target_ids,
        est_src=estimated_masks,
        target_src=target_masks,
        mix_spec=mix_mag,
    )
    return loss_value, loss_details
def unpack_data(self, batch):
    """Split a batch into mixture, dominant-source labels and ratio masks.

    Only index 0 of the last axis (the first channel) of the mixture, the
    sources and the noise is kept. Ratio masks are built from magnitude
    spectrograms produced by ``self.model.encoder`` followed by ``take_mag``:
    ``|S_i| / (|N| + sum_j |S_j| + EPS)``, with the sum over axis 1.

    Args:
        batch: A ``(mix, sources, noise)`` triple.

    Returns:
        tuple: ``(mix, binary_mask, real_mask)``. ``real_mask`` holds the
        per-source ratio masks. ``binary_mask`` is the index along axis 1 of
        the source with the largest mask value.
    """
    mixture, sources, noise = batch
    # Keep the first channel only.
    mixture, sources, noise = mixture[..., 0], sources[..., 0], noise[..., 0]
    source_mags = take_mag(self.model.encoder(sources))
    # The noise receives a singleton axis before encoding and another one at
    # axis 1 afterwards, so it broadcasts against the per-source axis
    # (assumed layout -- TODO confirm against the encoder's output shape).
    noise_mag = take_mag(self.model.encoder(noise.unsqueeze(1))).unsqueeze(1)
    denominator = noise_mag + source_mags.sum(1, keepdim=True) + EPS
    real_mask = source_mags / denominator
    # The denominator is shared by every source, so this picks the source
    # with the largest magnitude at each bin.
    binary_mask = real_mask.argmax(1)
    return mixture, binary_mask, real_mask
def unpack_data(self, batch):
    """Split a batch into mixture, dominant-source labels and ratio masks.

    Ratio masks are ``|S_i| / (sum_j |S_j| + EPS)``, with the sum over
    axis 1. The magnitudes come from ``self.model.encoder`` followed by
    ``take_mag``.

    Args:
        batch: A ``(mix, sources)`` pair.

    Returns:
        tuple: ``(mix, binary_mask, real_mask)``. ``mix`` is returned as
        received. ``binary_mask`` is ``real_mask.argmax(1)``, i.e. the index
        of the source with the largest mask value.
    """
    mixture, sources = batch
    source_mags = take_mag(self.model.encoder(sources))
    total_mag = source_mags.sum(1, keepdim=True) + EPS
    real_mask = source_mags / total_mag
    binary_mask = real_mask.argmax(1)
    return mixture, binary_mask, real_mask
def unpack_data(self, batch):
    """Return the first-channel mixture together with hard and soft masks.

    Args:
        batch: A sequence of exactly three items ``(mix, sources, noise)``.
            Each item is indexed with ``[..., 0]`` to keep the first channel.

    Returns:
        tuple: ``(mix, binary_mask, real_mask)``, where:

        - ``real_mask`` divides each source magnitude by the noise magnitude
          plus the source magnitudes summed over axis 1, plus ``EPS``.
        - ``binary_mask`` is ``real_mask.argmax(1)``.
    """
    def magnitude(signal):
        # Encode with the model front-end, then take the magnitude.
        return take_mag(self.model.encoder(signal))

    mix, sources, noise = batch
    mix, sources, noise = [item[..., 0] for item in (mix, sources, noise)]
    src_mag = magnitude(sources)
    # Singleton axis before encoding, and another at axis 1 after it, so the
    # noise magnitude broadcasts over the source axis (assumed shapes --
    # TODO confirm).
    noise_mag = magnitude(noise.unsqueeze(1)).unsqueeze(1)
    real_mask = src_mag / (noise_mag + src_mag.sum(1, keepdim=True) + EPS)
    return mix, real_mask.argmax(1), real_mask
def common_step(self, batch, batch_nb, train=False):
    """Compute the loss for a single batch.

    The batch is decoded with ``self.unpack_data``. The model (``self``)
    returns embeddings and estimated masks for the mixture. When
    ``self.mask_mixture`` is truthy, both the estimated and the target masks
    are multiplied by the mixture's magnitude spectrogram before the loss is
    computed.

    Args:
        batch: Raw batch passed to ``self.unpack_data``.
        batch_nb: Batch index; not used.
        train (bool): Training flag; not used.

    Returns:
        tuple: ``(loss, loss_dic)`` as unpacked from ``self.loss_func``.
    """
    inputs, targets, masks = self.unpack_data(batch)
    embeddings, est_masks = self(inputs)
    mix_spec = take_mag(self.model.encoder(inputs.unsqueeze(1)))
    if self.mask_mixture:
        est_masks, masks = (m * mix_spec.unsqueeze(1) for m in (est_masks, masks))
    outputs = self.loss_func(
        embeddings,
        targets,
        est_src=est_masks,
        target_src=masks,
        mix_spec=mix_spec,
    )
    loss, loss_dic = outputs
    return loss, loss_dic