Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def convergence_callback(Y, **kwargs):
    """Append the SDR and SIR of the current separation estimate to the logs.

    The STFT-domain estimate ``Y`` is synthesized back to the time domain,
    optionally reordered, and scored against the reference signals with
    ``mir_eval.separation.bss_eval_sources``.

    Relies on names from the enclosing/module scope: ``pra``, ``np``,
    ``framesize``, ``win_s``, ``args``, ``n_sources_target``, ``ref``, and
    the ``SDR`` / ``SIR`` lists that receive the results.

    Parameters
    ----------
    Y : array
        Current estimate; only ``Y.shape[2]`` and the data itself are used.
        Presumably laid out as (frames, frequencies, channels) -- TODO confirm.
    **kwargs
        Accepted and ignored.
    """
    global SDR, SIR, ref
    from mir_eval.separation import bss_eval_sources

    hop = framesize // 2

    # With a single output the last axis is dropped before synthesis and a
    # trailing axis is added back afterwards, so the column indexing below
    # works in both cases.
    single_output = Y.shape[2] == 1
    stft = Y[:, :, 0] if single_output else Y
    signals = pra.transform.synthesis(stft, framesize, hop, win=win_s)
    if single_output:
        signals = signals[:, None]

    if args.algo != "blinkiva":
        # Sort the columns by decreasing standard deviation.
        column_std = np.std(signals, axis=0)
        descending = np.argsort(column_std)[::-1]
        signals = signals[:, descending]

    # Number of samples scored: limited by the synthesized signal (after
    # skipping the first `hop` samples) and by the reference length.
    n_eval = np.minimum(signals.shape[0] - hop, ref.shape[1])

    references = ref[:n_sources_target, :n_eval, 0]
    estimates = signals[hop : n_eval + hop, :n_sources_target].T
    sdr, sir, _sar, _perm = bss_eval_sources(references, estimates)

    SDR.append(sdr)
    SIR.append(sir)
def convergence_callback(Y, n_targets, SDR, SIR, ref, framesize, win_s, algo_name):
    """Score the current separation estimate and append SDR/SIR to the logs.

    The STFT-domain estimate ``Y`` is synthesized to the time domain,
    optionally reordered, copied into the shared ``synth`` buffer, and scored
    with ``mir_eval.separation.bss_eval_sources``.

    Relies on ``pra``, ``np``, ``parameters`` and ``synth`` from the
    enclosing/module scope. Only the first ``n_targets`` rows of ``synth`` are
    written here; any further rows keep whatever they already contained.

    Parameters
    ----------
    Y : array
        Current estimate; presumably (frames, frequencies, channels)
        -- TODO confirm.
    n_targets : int
        Number of target sources to score.
    SDR, SIR : list
        Lists that receive one list of per-target values per call.
    ref : array
        Reference signals, indexed as ``ref[source, sample, 0]``.
    framesize : int
        STFT frame length; the hop is ``framesize // 2``.
    win_s :
        Synthesis window passed through to ``pra.transform.synthesis``.
    algo_name : str
        Name of the running algorithm; columns are reordered unless it is
        listed in ``parameters["overdet_algos"]``.
    """
    from mir_eval.separation import bss_eval_sources

    hop = framesize // 2

    # With a single output the last axis is dropped before synthesis and a
    # trailing axis is added back afterwards, so the column indexing below
    # works in both cases.
    single_output = Y.shape[2] == 1
    stft = Y[:, :, 0] if single_output else Y
    signals = pra.transform.synthesis(stft, framesize, hop, win=win_s)
    if single_output:
        signals = signals[:, None]

    if algo_name not in parameters["overdet_algos"]:
        # Sort the columns by decreasing standard deviation.
        column_std = np.std(signals, axis=0)
        descending = np.argsort(column_std)[::-1]
        signals = signals[:, descending]

    # Number of samples scored: limited by the synthesized signal (after
    # skipping the first `hop` samples) and by the reference length.
    n_eval = np.minimum(signals.shape[0] - hop, ref.shape[1])

    synth[:n_targets, :n_eval, 0] = signals[hop : n_eval + hop, :n_targets].T

    # The reference slice has one more row than the number of targets.
    sdr, sir, _sar, _perm = bss_eval_sources(
        ref[: n_targets + 1, :n_eval, 0], synth[:, :n_eval, 0]
    )

    SDR.append(sdr[:n_targets].tolist())
    SIR.append(sir[:n_targets].tolist())
def convergence_callback(Y, n_targets, SDR, SIR, ref, framesize, win_s, algo_name):
    """Synthesize the current estimate and evaluate it against the references.

    The STFT-domain estimate ``Y`` is synthesized to the time domain,
    optionally reordered, copied into the shared ``synth`` buffer, and passed
    to ``mir_eval.separation.bss_eval_sources``.

    Relies on ``pra``, ``np``, ``parameters`` and ``synth`` from the
    enclosing/module scope.

    NOTE(review): in the code visible here the computed metrics are neither
    stored nor returned, and the ``SDR`` / ``SIR`` arguments are unused. The
    function is presumably truncated -- confirm against the full source.

    Parameters
    ----------
    Y : array
        Current estimate; presumably (frames, frequencies, channels)
        -- TODO confirm.
    n_targets : int
        Number of target sources copied into ``synth``.
    SDR, SIR :
        Unused in the visible body.
    ref : array
        Reference signals, indexed as ``ref[source, sample, 0]``.
    framesize : int
        STFT frame length; the hop is ``framesize // 2``.
    win_s :
        Synthesis window passed through to ``pra.transform.synthesis``.
    algo_name : str
        Name of the running algorithm; columns are reordered unless it is
        listed in ``parameters["overdet_algos"]``.
    """
    from mir_eval.separation import bss_eval_sources

    hop = framesize // 2

    # With a single output the last axis is dropped before synthesis and a
    # trailing axis is added back afterwards, so the column indexing below
    # works in both cases.
    single_output = Y.shape[2] == 1
    stft = Y[:, :, 0] if single_output else Y
    signals = pra.transform.synthesis(stft, framesize, hop, win=win_s)
    if single_output:
        signals = signals[:, None]

    if algo_name not in parameters["overdet_algos"]:
        # Sort the columns by decreasing standard deviation.
        column_std = np.std(signals, axis=0)
        descending = np.argsort(column_std)[::-1]
        signals = signals[:, descending]

    # Number of samples scored: limited by the synthesized signal (after
    # skipping the first `hop` samples) and by the reference length.
    n_eval = np.minimum(signals.shape[0] - hop, ref.shape[1])

    synth[:n_targets, :n_eval, 0] = signals[hop : n_eval + hop, :n_targets].T

    sdr, sir, _sar, _perm = bss_eval_sources(
        ref[: n_targets + 1, :n_eval, 0], synth[:, :n_eval, 0]
    )
# NOTE(review): this fragment starts mid-statement and has lost its
# indentation. The line below is the tail of a call whose beginning is not
# visible, and the `if` that the following `elif` belongs to is also not
# visible.
callback=convergence_callback)
elif bss_type == 'sparseauxiva':
# Estimate set of active frequency bins
# Fraction of the entries of `average` that are kept as "active".
ratio = 0.35
# Magnitude of X averaged over axis 2 and then over axis 0.
# Presumably X is (frames, frequency bins, channels) -- TODO confirm.
average = np.abs(np.mean(np.mean(X, axis=2), axis=0))
k = np.int_(average.shape[0] * ratio)
# Indices of the k largest entries of `average`, in ascending index order.
# NOTE(review): if k evaluates to 0, `[-k:]` selects every index, not none.
S = np.sort(np.argpartition(average, -k)[-k:])
# Run SparseAuxIva
# 30 iterations with proj_back=True; convergence_callback is passed as the
# callback.
Y = pra.bss.sparseauxiva(X, S, n_iter=30, proj_back=True,
callback=convergence_callback)
# Report the elapsed time since t_begin (set outside the visible source).
t_end = time.perf_counter()
print("Time for BSS: {:.2f} s".format(t_end - t_begin))
## STFT Synthesis
# Convert the separated STFT Y back to the time domain with frame length L
# and hop size `hop`.
y = pra.transform.synthesis(Y, L, hop, win=win_s)
## Compare SDR and SIR
# Drop the first L - hop samples and transpose so that signals are indexed
# along the first axis, matching ref[:, :m, 0] below.
# Presumably the dropped samples compensate the STFT delay -- TODO confirm.
y = y[L-hop:, :].T
# Truncate both signals to their common length before evaluation.
m = np.minimum(y.shape[1], ref.shape[1])
sdr, sir, sar, perm = bss_eval_sources(ref[:, :m, 0], y[:, :m])
print('SDR:', sdr)
print('SIR:', sir)
## PLOT RESULTS
import matplotlib.pyplot as plt
plt.figure()
# First panel of a 2x2 grid: spectrogram of reference source 0, using a
# 1024-point FFT and the sampling frequency taken from `room`.
plt.subplot(2,2,1)
plt.specgram(ref[0,:,0], NFFT=1024, Fs=room.fs)
plt.title('Source 0 (clean)')
# Second panel of the grid; its contents are outside the visible source.
plt.subplot(2,2,2)