How to use the asteroid.data.librimix_dataset.LibriMix class in asteroid

To help you get started, we’ve selected a few asteroid examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github mpariente / asteroid / egs / librimix / ConvTasNet / eval.py View on Github external
def main(conf):
    """Evaluate a pretrained ConvTasNet on the LibriMix test set.

    ``conf`` is a dict of options read by this snippet: 'exp_dir',
    'use_gpu', 'test_dir', 'task', 'sample_rate', 'train_conf',
    'out_dir', 'n_save_ex'.
    """
    # Load the best checkpoint produced by training.
    model_path = os.path.join(conf['exp_dir'], 'best_model.pth')
    model = ConvTasNet.from_pretrained(model_path)
    # Handle device placement
    if conf['use_gpu']:
        model.cuda()
    # Remember where the model lives so inputs can be moved to it.
    model_device = next(model.parameters()).device
    test_set = LibriMix(csv_dir=conf['test_dir'],
                        task=conf['task'],
                        sample_rate=conf['sample_rate'],
                        n_src=conf['train_conf']['data']['n_src'],
                        segment=None)  # segment=None: use full utterance length
    # Used to reorder sources only (PIT resolves the source permutation).
    loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx')

    # Randomly choose the indexes of sentences to save.
    eval_save_dir = os.path.join(conf['exp_dir'], conf['out_dir'])
    ex_save_dir = os.path.join(eval_save_dir, 'examples/')
    # n_save_ex == -1 means "save every example".
    if conf['n_save_ex'] == -1:
        conf['n_save_ex'] = len(test_set)
    save_idx = random.sample(range(len(test_set)), conf['n_save_ex'])
    series_list = []
    # NOTE(review): __enter__() is called manually and __exit__() never is;
    # prefer `with torch.no_grad():` wrapping the loop below.
    torch.no_grad().__enter__()
    # Snippet is truncated here; the loop body is not visible in this excerpt.
    for idx in tqdm(range(len(test_set))):
github mpariente / asteroid / egs / libri_2_mix / ConvTasNet / eval.py View on Github external
# Fragment starts mid-function: load checkpoint weights on CPU first; the
# model is moved to GPU later only if requested.
checkpoint = torch.load(best_model_path, map_location='cpu')
    state = checkpoint['state_dict']
    # Work on a copy: deleting from `state` while iterating it would raise.
    state_copy = state.copy()
    # Remove unwanted keys
    for keys, values in state.items():
        if keys.startswith('loss'):
            del state_copy[keys]
            print(keys)
    model = load_state_dict_in(state_copy, model)

    # Handle device placement
    if conf['use_gpu']:
        model.cuda()
    model_device = next(model.parameters()).device

    # Positional args: csv_dir/metadata path, segment (None = full length),
    # sample rate, number of sources -- TODO confirm against LibriMix signature.
    test_set = LibriMix(conf['test_dir'], None,
                        conf['sample_rate'],
                        conf['train_conf']['data']['n_src'])

    # NOTE(review): `mode='pairwise'` looks like an older PITLossWrapper API;
    # the other snippet above uses `pit_from='pw_mtx'` -- verify version.
    loss_func = PITLossWrapper(pairwise_neg_sisdr, mode='pairwise')

    # Randomly choose the indexes of sentences to save.
    ex_save_dir = os.path.join(conf['exp_dir'], 'examples_mss_8K/')
    # n_save_ex == -1 means "save every example".
    if conf['n_save_ex'] == -1:
        conf['n_save_ex'] = len(test_set)
    save_idx = random.sample(range(len(test_set)), conf['n_save_ex'])
    series_list = []
    # NOTE(review): manual __enter__() with no matching __exit__();
    # prefer `with torch.no_grad():`.
    torch.no_grad().__enter__()
    for idx in tqdm(range(len(test_set))):
        # Forward the network on the mixture.
        mix, sources = tensors_to_device(test_set[idx], device=model_device)
github mpariente / asteroid / egs / libri_2_mix / ConvTasNet / train.py View on Github external
def main(conf):
    """Build LibriMix train/val datasets and loaders from the config dict.

    Snippet is truncated: the val_loader call is cut off mid-statement.
    """
    # Positional args: metadata path, desired segment length, sample rate,
    # number of sources -- TODO confirm against the LibriMix signature.
    train_set = LibriMix(conf['data']['metadata_train_path'],
                         conf['data']['desired_length'],
                         conf['data']['sample_rate'],
                         conf['data']['n_src'])

    val_set = LibriMix(conf['data']['metadata_val_path'],
                       conf['data']['desired_length'],
                       conf['data']['sample_rate'],
                       conf['data']['n_src'])

    # drop_last=True keeps every batch the same size (last partial batch dropped).
    train_loader = DataLoader(train_set, shuffle=True,
                              batch_size=conf['training']['batch_size'],
                              num_workers=conf['training']['num_workers'],
                              drop_last=True)
    # NOTE(review): shuffling the validation loader is unusual; results are
    # unaffected but ordering is not reproducible -- confirm intent.
    val_loader = DataLoader(val_set, shuffle=True,
                            batch_size=conf['training']['batch_size'],
                            num_workers=conf['training']['num_workers'],
github mpariente / asteroid / egs / libri_2_mix / ConvTasNet / train.py View on Github external
def main(conf):
    """Build LibriMix train/val datasets and loaders, then set up the model.

    Snippet is truncated after the masknet update; model/optimizer setup is
    not visible in this excerpt.
    """
    # Positional args: metadata path, desired segment length, sample rate,
    # number of sources -- TODO confirm against the LibriMix signature.
    train_set = LibriMix(conf['data']['metadata_train_path'],
                         conf['data']['desired_length'],
                         conf['data']['sample_rate'],
                         conf['data']['n_src'])

    val_set = LibriMix(conf['data']['metadata_val_path'],
                       conf['data']['desired_length'],
                       conf['data']['sample_rate'],
                       conf['data']['n_src'])

    # drop_last=True keeps every batch the same size (last partial batch dropped).
    train_loader = DataLoader(train_set, shuffle=True,
                              batch_size=conf['training']['batch_size'],
                              num_workers=conf['training']['num_workers'],
                              drop_last=True)
    # NOTE(review): shuffling the validation loader is unusual; confirm intent.
    val_loader = DataLoader(val_set, shuffle=True,
                            batch_size=conf['training']['batch_size'],
                            num_workers=conf['training']['num_workers'],
                            drop_last=True)

    # NOTE(review): hard-codes the separator to 2 sources even though the
    # datasets above read n_src from the config -- verify this is deliberate.
    conf['masknet'].update({'n_src': 2})

    # Define model and optimizer in a local function (defined in the recipe).