Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def _test_lfilter_basic(self, dtype, device):
    """Check that ``F.lfilter`` with a shifted unit tap only delays the signal.

    A seeded random two-channel signal is filtered with
    ``b_coeffs = [0, 0, 0, 1]`` and ``a_coeffs = [1, 0, 0, 0]``.  The
    output is expected to equal the input moved three samples later, so
    the input without its last three samples is compared against the
    output without its first three.
    """
    shift = 3  # position of the single non-zero entry in the b coefficients
    tensor_kwargs = {"dtype": dtype, "device": device}

    # Fixed seed keeps the random test signal reproducible between runs.
    torch.random.manual_seed(42)
    signal = torch.rand(2, 44100, **tensor_kwargs)

    b_taps = torch.tensor([0, 0, 0, 1], **tensor_kwargs)
    a_taps = torch.tensor([1, 0, 0, 0], **tensor_kwargs)

    # F.lfilter takes the a coefficients before the b coefficients.
    filtered = F.lfilter(signal, a_taps, b_taps)

    expected = signal[:, :-shift]
    delayed = filtered[:, shift:]
    assert torch.allclose(expected, delayed, atol=1e-5)
device=device,
)
# NOTE(review): excerpt from a larger test -- the enclosing function and the
# definitions of `waveform`, `b_coeffs` and `device` lie outside these lines.
# Seven coefficients, passed to F.lfilter below as its second argument.
a_coeffs = torch.tensor(
[
1.0,
-4.8155751,
10.2217618,
-12.14481273,
8.49018171,
-3.3066882,
0.56088705,
],
device=device,
)
output_waveform = F.lfilter(waveform, a_coeffs, b_coeffs)
# The filtered output must be 2-D and match the input's size along both
# dimensions.
assert len(output_waveform.size()) == 2
assert output_waveform.size(0) == waveform.size(0)
assert output_waveform.size(1) == waveform.size(1)
# Hand the same call to the helper defined elsewhere; presumably it checks
# that the TorchScript-compiled function agrees with eager mode -- confirm.
_test_torchscript_functional(F.lfilter, waveform, a_coeffs, b_coeffs)
# NOTE(review): excerpt from a larger test -- `b0`, `b1` and `fn_sine` are
# defined outside the lines shown here.
b2 = 0.9
a0 = 0.7
a1 = 0.2
a2 = 0.6
# SoX method
E = torchaudio.sox_effects.SoxEffectsChain()
E.set_input_file(fn_sine)
# Timed window: appending the biquad effect and building the effects flow.
_timing_sox = time.time()
# The "biquad" effect receives the coefficients as b0, b1, b2, a0, a1, a2.
E.append_effect_to_chain("biquad", [b0, b1, b2, a0, a1, a2])
waveform_sox_out, sr = E.sox_build_flow_effects()
_timing_sox_run_time = time.time() - _timing_sox
# lfilter method. This timed window covers torchaudio.load as well as the
# F.lfilter call. Neither timing result is read within the lines shown.
_timing_lfilter_filtering = time.time()
waveform, sample_rate = torchaudio.load(fn_sine, normalization=True)
# F.lfilter takes the a coefficients first and the b coefficients second,
# the opposite grouping from the biquad argument list above.
waveform_lfilter_out = F.lfilter(
waveform, torch.tensor([a0, a1, a2]), torch.tensor([b0, b1, b2])
)
_timing_lfilter_run_time = time.time() - _timing_lfilter_filtering
# The two implementations must agree within atol=1e-4 (default rtol).
assert torch.allclose(waveform_sox_out, waveform_lfilter_out, atol=1e-4)
# Hand the same call to the helper defined elsewhere; presumably it checks
# that the TorchScript-compiled function agrees with eager mode -- confirm.
_test_torchscript_functional(
F.lfilter, waveform, torch.tensor([a0, a1, a2]), torch.tensor([b0, b1, b2])
)
1.0,
-4.8155751,
10.2217618,
-12.14481273,
8.49018171,
-3.3066882,
0.56088705,
],
device=device,
)
# NOTE(review): excerpt from a larger test -- `waveform`, `a_coeffs` and
# `b_coeffs` are built (at least partly) outside the lines shown here.
output_waveform = F.lfilter(waveform, a_coeffs, b_coeffs)
# The filtered output must be 2-D and match the input's size along both
# dimensions.
assert len(output_waveform.size()) == 2
assert output_waveform.size(0) == waveform.size(0)
assert output_waveform.size(1) == waveform.size(1)
# Hand the same call to the helper defined elsewhere; presumably it checks
# that the TorchScript-compiled function agrees with eager mode -- confirm.
_test_torchscript_functional(F.lfilter, waveform, a_coeffs, b_coeffs)