Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# Librosa-consistency checks: compare torchaudio's Spectrogram, MelSpectrogram
# and AmplitudeToDB outputs against librosa on the same sine-wave file.
# NOTE(review): this span is the tail of a test method whose `def` line is not
# visible here; n_fft, hop_length and n_mels are free names -- presumably
# defined earlier in that method; confirm.
input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
sound, sample_rate = torchaudio.load(input_path)
# numpy copy of the signal for librosa; squeeze() removes size-1 dimensions
# (assumes a single-channel file, which leaves a 1-D sample array).
sound_librosa = sound.cpu().numpy().squeeze() # (64000)
# test core spectrogram
# power=2 on both sides so the two spectrograms are directly comparable.
spect_transform = torchaudio.transforms.Spectrogram(n_fft=n_fft, hop_length=hop_length, power=2)
# `_spectrogram` is an underscore-prefixed (private) librosa helper that
# returns a pair; only the first element is kept. Private helpers may change
# between librosa releases.
out_librosa, _ = librosa.core.spectrum._spectrogram(y=sound_librosa,
n_fft=n_fft,
hop_length=hop_length,
power=2)
# Drop size-1 dimensions so the torch result lines up with librosa's array.
out_torch = spect_transform(sound).squeeze().cpu()
self.assertTrue(torch.allclose(out_torch, torch.from_numpy(out_librosa), atol=1e-5))
# test mel spectrogram
melspect_transform = torchaudio.transforms.MelSpectrogram(
sample_rate=sample_rate, window_fn=torch.hann_window,
hop_length=hop_length, n_mels=n_mels, n_fft=n_fft)
# htk=True / norm=None presumably select the mel-filter variant matching
# torchaudio's filterbank -- confirm against the torchaudio version in use.
librosa_mel = librosa.feature.melspectrogram(y=sound_librosa, sr=sample_rate,
n_fft=n_fft, hop_length=hop_length, n_mels=n_mels,
htk=True, norm=None)
librosa_mel_tensor = torch.from_numpy(librosa_mel)
torch_mel = melspect_transform(sound).squeeze().cpu()
# Cast the torch result to librosa's dtype before comparing (presumably because
# torch.allclose needs matching dtypes); the mel check uses a looser tolerance.
self.assertTrue(torch.allclose(torch_mel.type(librosa_mel_tensor.dtype), librosa_mel_tensor, atol=5e-3))
# test s2db
# The second positional argument is presumably top_db (80 dB); test_mel2 in
# this file passes it under that name.
db_transform = torchaudio.transforms.AmplitudeToDB('power', 80.)
db_torch = db_transform(spect_transform(sound)).squeeze().cpu()
# NOTE(review): relies on librosa's power_to_db defaults (ref, amin, top_db)
# agreeing with the torchaudio transform above -- confirm.
db_librosa = librosa.core.spectrum.power_to_db(out_librosa)
self.assertTrue(torch.allclose(db_torch, torch.from_numpy(db_librosa), atol=5e-3))
def test_melspectrogram_load_save(self):
    """A MelSpectrogram restored from another's state_dict matches it.

    The source transform is run once on ``self.waveform`` before its state is
    copied into a fresh transform. The test then compares the two transforms'
    spectrogram windows and mel filterbanks.

    NOTE(review): a second, identical ``test_melspectrogram_load_save`` appears
    later in this file; if both live in one class, only the later one runs.
    """
    signal = self.waveform.float()
    original = transforms.MelSpectrogram()
    # One forward pass before snapshotting the state. Presumably this matters
    # when buffers are filled lazily on first use -- confirm for the torchaudio
    # version in use.
    original(signal)

    restored = transforms.MelSpectrogram()
    restored.load_state_dict(original.state_dict())

    window_src, window_dst = (m.spectrogram.window for m in (original, restored))
    fb_src, fb_dst = (m.mel_scale.fb for m in (original, restored))

    self.assertTrue(torch.allclose(window_src, window_dst))
    # With the defaults (n_fft = 400, n_mels = 128) the bank is (201, 128).
    self.assertEqual(fb_dst.size(), (201, 128))
    self.assertTrue(torch.allclose(fb_src, fb_dst))
def test_scriptmodule_MelSpectrogram(self):
    """Script-check MelSpectrogram, then sanity-check its dB-scaled output.

    The test first hands a random (1, 1000) tensor and the MelSpectrogram
    class to ``self._test_script_module`` (helper not visible here).

    It then applies the transform to ``self.waveform`` scaled by
    ``self.scale``, converts the result to dB, and checks the dB spectrogram:
    it is 3-D, every value lies within ``top_db`` of its maximum, and it has
    ``n_mels`` rows. The mel filterbank row sums must also lie in [0, 1].
    These checks run for the default transform, a customised transform,
    stereo input, and an explicitly sized MelScale.

    NOTE(review): another ``test_scriptmodule_MelSpectrogram`` is defined
    later in this file; if both live in the same class, the later one
    replaces this one.
    """
    # torch.rand(device="cuda") raises on CPU-only machines, which made every
    # check below unreachable; fall back to CPU when CUDA is unavailable.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tensor = torch.rand((1, 1000), device=device)
    self._test_script_module(tensor, transforms.MelSpectrogram)

    # These were referenced below but never defined (NameError); use the
    # same values as test_mel2.
    top_db = 80.
    s2db = transforms.AmplitudeToDB('power', top_db)

    waveform = self.waveform.clone()
    waveform_scaled = self.scale(waveform)

    # check defaults
    mel_transform = transforms.MelSpectrogram()
    spectrogram_torch = s2db(mel_transform(waveform_scaled))  # (channel, n_mels, time)
    self.assertEqual(spectrogram_torch.dim(), 3)
    # Every value should lie within top_db of the maximum.
    self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all())
    self.assertEqual(spectrogram_torch.size(1), mel_transform.n_mels)
    # check correctness of filterbank conversion matrix
    self.assertTrue(mel_transform.mel_scale.fb.sum(1).le(1.).all())
    self.assertTrue(mel_transform.mel_scale.fb.sum(1).ge(0.).all())

    # check options
    kwargs = {'window_fn': torch.hamming_window, 'pad': 10, 'win_length': 500,
              'hop_length': 125, 'n_fft': 800, 'n_mels': 50}
    mel_transform2 = transforms.MelSpectrogram(**kwargs)
    spectrogram2_torch = s2db(mel_transform2(waveform_scaled))  # (channel, n_mels, time)
    self.assertEqual(spectrogram2_torch.dim(), 3)
    # Previously re-checked spectrogram_torch (copy-paste slip).
    self.assertTrue(spectrogram2_torch.ge(spectrogram2_torch.max() - top_db).all())
    self.assertEqual(spectrogram2_torch.size(1), mel_transform2.n_mels)
    self.assertTrue(mel_transform2.mel_scale.fb.sum(1).le(1.).all())
    self.assertTrue(mel_transform2.mel_scale.fb.sum(1).ge(0.).all())

    # check on multi-channel audio (the file is expected to have 2 channels)
    x_stereo, _ = torchaudio.load(self.test_filepath)
    spectrogram_stereo = s2db(mel_transform(x_stereo))  # (2, n_mels, time)
    self.assertEqual(spectrogram_stereo.dim(), 3)
    self.assertEqual(spectrogram_stereo.size(0), 2)
    # Previously re-checked spectrogram_torch (copy-paste slip).
    self.assertTrue(spectrogram_stereo.ge(spectrogram_stereo.max() - top_db).all())
    self.assertEqual(spectrogram_stereo.size(1), mel_transform.n_mels)

    # check filterbank matrix creation (previously built but never asserted)
    fb_matrix_transform = transforms.MelScale(
        n_mels=100, sample_rate=16000, f_min=0., f_max=None, n_stft=400)
    fb = fb_matrix_transform.fb
    # Laid out as (n_stft, n_mels), like the (201, 128) default bank checked
    # in test_melspectrogram_load_save.
    self.assertEqual(fb.size(), (400, 100))
    self.assertTrue(fb.sum(1).le(1.).all())
    self.assertTrue(fb.sum(1).ge(0.).all())
def test_scriptmodule_MelSpectrogram(self):
    """Run MelSpectrogram through the TorchScript comparison helper.

    A random (1, 1000) CPU tensor is the sample input. The comparison itself
    is done by the ``_test_script_module`` helper, which is defined outside
    this view.
    """
    random_input = torch.rand((1, 1000))
    _test_script_module(transforms.MelSpectrogram, random_input)
def test_melspectrogram_load_save(self):
    """State copied via state_dict reproduces window and filterbank.

    The test builds one MelSpectrogram and runs it on ``self.waveform``. It
    then loads that transform's state_dict into a second, fresh transform and
    checks that both transforms hold the same window and mel filterbank.
    """
    input_wave = self.waveform.float()
    trained = transforms.MelSpectrogram()
    trained(input_wave)  # one forward pass before the state is copied
    clone = transforms.MelSpectrogram()
    clone.load_state_dict(trained.state_dict())

    win, win_clone = trained.spectrogram.window, clone.spectrogram.window
    fbank, fbank_clone = trained.mel_scale.fb, clone.mel_scale.fb

    self.assertTrue(torch.allclose(win, win_clone))
    # Defaults are n_fft = 400 and n_mels = 128; the bank is expected (201, 128).
    self.assertEqual(fbank_clone.size(), (201, 128))
    self.assertTrue(torch.allclose(fbank, fbank_clone))
def test_mel2(self):
    """Sanity-check dB-scaled MelSpectrogram output for two configurations.

    For the default transform and for a customised one (hamming window,
    padding, custom window/hop/FFT sizes and 50 mel bands), the test checks:

    * the dB output is 3-D;
    * every value lies within ``top_db`` of its maximum;
    * dimension 1 equals ``n_mels``;
    * the mel filterbank row sums lie in [0, 1].
    """
    top_db = 80.
    s2db = transforms.AmplitudeToDB('power', top_db)
    waveform = self.waveform.clone()
    waveform_scaled = self.scale(waveform)

    # check defaults
    mel_transform = transforms.MelSpectrogram()
    spectrogram_torch = s2db(mel_transform(waveform_scaled))  # (channel, n_mels, time)
    self.assertEqual(spectrogram_torch.dim(), 3)
    # Every value should lie within top_db of the maximum.
    self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all())
    self.assertEqual(spectrogram_torch.size(1), mel_transform.n_mels)
    # check correctness of filterbank conversion matrix
    self.assertTrue(mel_transform.mel_scale.fb.sum(1).le(1.).all())
    self.assertTrue(mel_transform.mel_scale.fb.sum(1).ge(0.).all())

    # check options
    kwargs = {'window_fn': torch.hamming_window, 'pad': 10, 'win_length': 500,
              'hop_length': 125, 'n_fft': 800, 'n_mels': 50}
    mel_transform2 = transforms.MelSpectrogram(**kwargs)
    spectrogram2_torch = s2db(mel_transform2(waveform_scaled))  # (channel, n_mels, time)
    self.assertEqual(spectrogram2_torch.dim(), 3)
    # Previously re-checked spectrogram_torch here (copy-paste slip), so the
    # customised output's top_db floor was never verified.
    self.assertTrue(spectrogram2_torch.ge(spectrogram2_torch.max() - top_db).all())
    self.assertEqual(spectrogram2_torch.size(1), mel_transform2.n_mels)
    # Same filterbank sanity check as for the default transform.
    self.assertTrue(mel_transform2.mel_scale.fb.sum(1).le(1.).all())
    self.assertTrue(mel_transform2.mel_scale.fb.sum(1).ge(0.).all())
# TODO See https://github.com/pytorch/audio/issues/165
class Spectrogram:
    """Exposes ``forward`` of a default-constructed torchaudio Spectrogram.

    The torchaudio transform is instantiated once, while this class body
    runs, and ``forward`` stays bound to that instance.
    """

    _transform = torchaudio.transforms.Spectrogram()
    forward = _transform.forward
    del _transform  # only ``forward`` is exposed on the class
class AmplitudeToDB:
    """Exposes ``forward`` of a default-constructed torchaudio AmplitudeToDB.

    The torchaudio transform is instantiated once, while this class body
    runs, and ``forward`` stays bound to that instance.
    """

    _transform = torchaudio.transforms.AmplitudeToDB()
    forward = _transform.forward
    del _transform  # only ``forward`` is exposed on the class
class MelScale:
    """Exposes ``forward`` of a default-constructed torchaudio MelScale.

    The torchaudio transform is instantiated once, while this class body
    runs, and ``forward`` stays bound to that instance.
    """

    _transform = torchaudio.transforms.MelScale()
    forward = _transform.forward
    del _transform  # only ``forward`` is exposed on the class
class MelSpectrogram:
    """Exposes ``forward`` of a default-constructed torchaudio MelSpectrogram.

    The torchaudio transform is instantiated once, while this class body
    runs, and ``forward`` stays bound to that instance.

    NOTE(review): this class takes no constructor arguments, yet
    ``Resample.__init__`` in this file calls ``MelSpectrogram(sample_rate=...,
    ...)``. If both names share a scope, that call raises TypeError -- confirm
    which binding is intended there.
    """

    _transform = torchaudio.transforms.MelSpectrogram()
    forward = _transform.forward
    del _transform  # only ``forward`` is exposed on the class
class MFCC:
    """Exposes ``forward`` of a default-constructed torchaudio MFCC.

    The torchaudio transform is instantiated once, while this class body
    runs, and ``forward`` stays bound to that instance.
    """

    _transform = torchaudio.transforms.MFCC()
    forward = _transform.forward
    del _transform  # only ``forward`` is exposed on the class
class MuLawEncoding:
    """Exposes ``forward`` of a default-constructed torchaudio MuLawEncoding.

    The torchaudio transform is instantiated once, while this class body
    runs, and ``forward`` stays bound to that instance.
    """

    _transform = torchaudio.transforms.MuLawEncoding()
    forward = _transform.forward
    del _transform  # only ``forward`` is exposed on the class
class MuLawDecoding:
    """Exposes ``forward`` of a default-constructed torchaudio MuLawDecoding.

    The torchaudio transform is instantiated once, while this class body
    runs, and ``forward`` stays bound to that instance.
    """

    _transform = torchaudio.transforms.MuLawDecoding()
    forward = _transform.forward
    del _transform  # only ``forward`` is exposed on the class
# NOTE(review): despite its name, this class's __init__ takes mel-spectrogram
# settings (sample_rate, mel_size, n_fft, win_length, hop_length, dB bounds)
# rather than resampling parameters; it may have been pasted from another
# class -- confirm.
class Resample:
# Resample isn't a script_method
def __init__(self, sample_rate: int, mel_size: int, n_fft: int, win_length: int,
hop_length: int, min_db: float, max_db: float,
mel_min: float = 0., mel_max: float = None):
super().__init__()
# Number of mel bands; also passed to the mel transform below as n_mels.
self.mel_size = mel_size
# db to log
# Convert the dB bounds to natural-log power units:
# log(10 ** (db / 10)) == db * ln(10) / 10.
self.min_db = np.log(np.power(10, min_db / 10))
self.max_db = np.log(np.power(10, max_db / 10))
# NOTE(review): `MelSpectrogram` is resolved by name when this runs. If it
# refers to the `class MelSpectrogram` wrapper defined earlier in this file
# (which declares no __init__), these keyword arguments raise TypeError;
# torchaudio.transforms.MelSpectrogram was presumably intended -- confirm
# which binding is in scope.
self.melfunc = MelSpectrogram(sample_rate=sample_rate, n_fft=n_fft, win_length=win_length,
hop_length=hop_length, f_min=mel_min, f_max=mel_max, n_mels=mel_size,
window_fn=torch.hann_window)