Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def main(conf):
# Training-recipe entry point. Only the data/model setup is visible in this
# snippet; the rest of the function body is not shown.
#
# Reads from ``conf`` (nested dict):
#   conf['data']: 'train_dir', 'valid_dir', 'task', 'sample_rate',
#                 'nondefault_nsrc' -- forwarded to WhamDataset.
#   conf['training']: 'batch_size', 'num_workers' -- forwarded to DataLoader.
#   conf['masknet']: dict updated IN PLACE with the 'n_src' key below.
# Disabled alternative data source (asteroid.data.toy_data.WavSet), not used:
# from asteroid.data.toy_data import WavSet
# train_set = WavSet(n_ex=1000, n_src=2, ex_len=32000)
# val_set = WavSet(n_ex=1000, n_src=2, ex_len=32000)
# Define data pipeline: one WhamDataset per split, same task / sample rate.
train_set = WhamDataset(conf['data']['train_dir'], conf['data']['task'],
sample_rate=conf['data']['sample_rate'],
nondefault_nsrc=conf['data']['nondefault_nsrc'])
val_set = WhamDataset(conf['data']['valid_dir'], conf['data']['task'],
sample_rate=conf['data']['sample_rate'],
nondefault_nsrc=conf['data']['nondefault_nsrc'])
# Training batches are shuffled; validation order is kept fixed.
train_loader = DataLoader(train_set, shuffle=True,
batch_size=conf['training']['batch_size'],
num_workers=conf['training']['num_workers'])
val_loader = DataLoader(val_set, shuffle=False,
batch_size=conf['training']['batch_size'],
num_workers=conf['training']['num_workers'])
# The number of sources depends on the task, so copy it from the dataset into
# the mask-network config before the model is built (presumably consumed by
# make_model_and_optimizer, which receives the same conf).
conf['masknet'].update({'n_src': train_set.n_src})
# The model and optimizer are built by a helper defined in the recipe, so they
# can be re-instantiated the same way for retraining and for evaluation.
model, optimizer = make_model_and_optimizer(conf)
def main(conf):
# Training-recipe entry point. Only the data-loading setup is visible in this
# snippet; the rest of the function body is not shown.
#
# Reads conf['data'] ('train_dir', 'valid_dir', 'task', 'sample_rate',
# 'segment', 'nondefault_nsrc') and conf['training'] ('batch_size',
# 'num_workers'); mutates conf['masknet'] in place.
# Only the training set passes an explicit ``segment``; the validation set
# falls back to WhamDataset's default for that argument (not visible here).
train_set = WhamDataset(conf['data']['train_dir'], conf['data']['task'],
sample_rate=conf['data']['sample_rate'], segment=conf['data']['segment'],
nondefault_nsrc=conf['data']['nondefault_nsrc'])
val_set = WhamDataset(conf['data']['valid_dir'], conf['data']['task'],
sample_rate=conf['data']['sample_rate'],
nondefault_nsrc=conf['data']['nondefault_nsrc'])
# drop_last=True: the final partial batch is skipped, so every training batch
# holds exactly batch_size examples.
train_loader = DataLoader(train_set, shuffle=True,
batch_size=conf['training']['batch_size'],
num_workers=conf['training']['num_workers'],
drop_last=True)
# NOTE(review): the validation loader both shuffles and drops the last partial
# batch. When len(val_set) is not a multiple of batch_size, a different subset
# of validation examples is discarded every epoch, so validation scores are
# computed on a changing subset. Consider shuffle=False (and drop_last=False
# unless the model requires fixed batch sizes) -- confirm before changing.
val_loader = DataLoader(val_set, shuffle=True,
batch_size=conf['training']['batch_size'],
num_workers=conf['training']['num_workers'],
drop_last=True)
# Update the number of sources (it depends on the task) in the mask-network
# config; this mutates conf in place.
conf['masknet'].update({'n_src': train_set.n_src})
def get_data_loaders(conf, train_part='filterbank'):
# Build the train/validation DataLoaders for one training part. This snippet
# is cut off inside the validation DataLoader call; the return value is not
# visible here.
#
# Args:
#   conf: nested config dict. Reads conf['data'] ('train_dir', 'valid_dir',
#     'task', 'sample_rate', 'nondefault_nsrc') and the per-part section
#     conf[train_part + '_training'], whose keys are prefixed with the first
#     letter of train_part: e.g. conf['filterbank_training']['f_batch_size'] /
#     ['f_num_workers'], or conf['separator_training']['s_batch_size'] /
#     ['s_num_workers'].
#   train_part: 'filterbank' (default) or 'separator'.
# Raises:
#   ValueError: if train_part is not one of the two supported values.
# Both splits are loaded with normalize_audio=True and no explicit segment.
train_set = WhamDataset(conf['data']['train_dir'], conf['data']['task'],
sample_rate=conf['data']['sample_rate'],
nondefault_nsrc=conf['data']['nondefault_nsrc'],
normalize_audio=True)
val_set = WhamDataset(conf['data']['valid_dir'], conf['data']['task'],
sample_rate=conf['data']['sample_rate'],
nondefault_nsrc=conf['data']['nondefault_nsrc'],
normalize_audio=True)
# NOTE(review): train_part is validated only after both datasets have been
# constructed; checking it first would fail fast without building them.
if train_part not in ['filterbank', 'separator']:
raise ValueError('Part to train: {} is not available.'.format(
train_part))
train_loader = DataLoader(train_set, shuffle=True, drop_last=True,
batch_size=conf[train_part + '_training'][
train_part[0] + '_batch_size'],
num_workers=conf[train_part + '_training'][
train_part[0] + '_num_workers'])
# NOTE(review): drop_last=True on validation skips the final partial batch, so
# up to batch_size - 1 validation examples are never evaluated -- confirm this
# is intended (e.g. the model needs fixed batch sizes).
val_loader = DataLoader(val_set, shuffle=False, drop_last=True,
batch_size=conf[train_part + '_training'][
train_part[0] + '_batch_size'],
def main(conf):
# Training-recipe entry point. Only the data-loading setup is visible in this
# snippet; the rest of the function body is not shown.
#
# Reads conf['data'] ('train_dir', 'valid_dir', 'task', 'sample_rate',
# 'segment', 'nondefault_nsrc') and conf['training'] ('batch_size',
# 'num_workers'); mutates conf['masknet'] in place.
# Only the training set passes an explicit ``segment``; the validation set
# falls back to WhamDataset's default for that argument (not visible here).
train_set = WhamDataset(conf['data']['train_dir'], conf['data']['task'],
sample_rate=conf['data']['sample_rate'], segment=conf['data']['segment'],
nondefault_nsrc=conf['data']['nondefault_nsrc'])
val_set = WhamDataset(conf['data']['valid_dir'], conf['data']['task'],
sample_rate=conf['data']['sample_rate'],
nondefault_nsrc=conf['data']['nondefault_nsrc'])
# drop_last=True: the final partial batch is skipped, so every training batch
# holds exactly batch_size examples.
train_loader = DataLoader(train_set, shuffle=True,
batch_size=conf['training']['batch_size'],
num_workers=conf['training']['num_workers'],
drop_last=True)
# NOTE(review): the validation loader both shuffles and drops the last partial
# batch. When len(val_set) is not a multiple of batch_size, a different subset
# of validation examples is discarded every epoch, so validation scores are
# computed on a changing subset. Consider shuffle=False (and drop_last=False
# unless the model requires fixed batch sizes) -- confirm before changing.
val_loader = DataLoader(val_set, shuffle=True,
batch_size=conf['training']['batch_size'],
num_workers=conf['training']['num_workers'],
drop_last=True)
# Update the number of sources (it depends on the task) in the mask-network
# config; this mutates conf in place.
conf['masknet'].update({'n_src': train_set.n_src})
# The model and optimizer are meant to be built next by a helper defined in
# the recipe (not visible in this snippet), so that they can be re-instantiated
# the same way for retraining and evaluation.
def get_data_loaders(conf, train_part='filterbank'):
# Build the train/validation DataLoaders for one training part. This snippet
# is cut off inside the training DataLoader call; the rest of the function
# (validation loader, return value) is not visible here.
#
# Args:
#   conf: nested config dict. Reads conf['data'] ('train_dir', 'valid_dir',
#     'task', 'sample_rate', 'nondefault_nsrc') and the per-part section
#     conf[train_part + '_training'], whose keys are prefixed with the first
#     letter of train_part: e.g. conf['filterbank_training']['f_batch_size'],
#     or conf['separator_training']['s_batch_size'].
#   train_part: 'filterbank' (default) or 'separator'.
# Raises:
#   ValueError: if train_part is not one of the two supported values.
# Both splits are loaded with normalize_audio=True and no explicit segment.
train_set = WhamDataset(conf['data']['train_dir'], conf['data']['task'],
sample_rate=conf['data']['sample_rate'],
nondefault_nsrc=conf['data']['nondefault_nsrc'],
normalize_audio=True)
val_set = WhamDataset(conf['data']['valid_dir'], conf['data']['task'],
sample_rate=conf['data']['sample_rate'],
nondefault_nsrc=conf['data']['nondefault_nsrc'],
normalize_audio=True)
# NOTE(review): train_part is validated only after both datasets have been
# constructed; checking it first would fail fast without building them.
if train_part not in ['filterbank', 'separator']:
raise ValueError('Part to train: {} is not available.'.format(
train_part))
# Training batches are shuffled and the final partial batch is dropped.
train_loader = DataLoader(train_set, shuffle=True, drop_last=True,
batch_size=conf[train_part + '_training'][
train_part[0] + '_batch_size'],
num_workers=conf[train_part + '_training'][
def main(conf):
model = load_best_separator_if_available(conf['train_conf'])
# Handle device placement
if conf['use_gpu']:
model.cuda()
model_device = next(model.parameters()).device
test_set = WhamDataset(conf['test_dir'], conf['task'],
sample_rate=conf['sample_rate'],
nondefault_nsrc=model.separator.n_sources,
segment=None, normalize_audio=True)
# Used to reorder sources only
loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx')
# Randomly choose the indexes of sentences to save.
ex_save_dir = os.path.join(conf['exp_dir'], 'examples/')
if conf['n_save_ex'] == -1:
conf['n_save_ex'] = len(test_set)
save_idx = random.sample(range(len(test_set)), conf['n_save_ex'])
series_list = []
torch.no_grad().__enter__()
cnt = 0
for idx in tqdm(range(len(test_set))):
# Forward the network on the mixture.