How to use the asteroid.torch_utils.load_state_dict_in function in asteroid

To help you get started, we’ve selected a few asteroid examples based on popular ways this function is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

Source (GitHub): mpariente/AsSteroid — egs/fuss/baseline/model.py (view on GitHub)
exp_dir(str): Experiment directory. Expects to find
            `'best_k_models.json'` there.

    Returns:
        nn.Module the best pretrained model according to the val_loss.
    """
    # Create the model from recipe-local function
    model, _ = make_model_and_optimizer(train_conf)
    # Last best model summary
    with open(os.path.join(exp_dir, 'best_k_models.json'), "r") as f:
        best_k = json.load(f)
    best_model_path = min(best_k, key=best_k.get)
    # Load checkpoint
    checkpoint = torch.load(best_model_path, map_location='cpu')
    # Load state_dict into model.
    model = torch_utils.load_state_dict_in(checkpoint['state_dict'],
                                           model)
    model.eval()
    return model
Source (GitHub): mpariente/AsSteroid — egs/libri_2_mix/ConvTasNet/eval.py (view on GitHub)
# Make the model
    model, _ = make_model_and_optimizer(conf['train_conf'])
    # Load best model
    with open(os.path.join(conf['exp_dir'], 'best_k_models.json'), "r") as f:
        best_k = json.load(f)
    best_model_path = min(best_k, key=best_k.get)
    # Load checkpoint
    checkpoint = torch.load(best_model_path, map_location='cpu')
    state = checkpoint['state_dict']
    state_copy = state.copy()
    # Remove unwanted keys
    for keys, values in state.items():
        if keys.startswith('loss'):
            del state_copy[keys]
            print(keys)
    model = load_state_dict_in(state_copy, model)

    # Handle device placement
    if conf['use_gpu']:
        model.cuda()
    model_device = next(model.parameters()).device

    test_set = LibriMix(conf['test_dir'], None,
                        conf['sample_rate'],
                        conf['train_conf']['data']['n_src'])

    loss_func = PITLossWrapper(pairwise_neg_sisdr, mode='pairwise')

    # Randomly choose the indexes of sentences to save.
    ex_save_dir = os.path.join(conf['exp_dir'], 'examples_mss_8K/')
    if conf['n_save_ex'] == -1:
        conf['n_save_ex'] = len(test_set)
Source (GitHub): mpariente/AsSteroid — egs/wham/TwoStep/model.py (view on GitHub)
if os.path.exists(checkpoint_dir):
        available_models = [p for p in os.listdir(checkpoint_dir)
                            if '.ckpt' in p]
        if available_models:
            model_available = True

    if not model_available:
        raise FileNotFoundError('There is no available separator model at: {}'
                                ''.format(checkpoint_dir))

    model_path = os.path.join(checkpoint_dir, available_models[0])
    print('Going to load from: {}'.format(model_path))
    checkpoint = torch.load(model_path, map_location='cpu')
    model_c, _ = make_model_and_optimizer(conf, model_part='separator',
                                          pretrained_filterbank=filterbank)
    model = torch_utils.load_state_dict_in(checkpoint['state_dict'], model_c)
    print('Successfully loaded separator from: {}'.format(model_path))
    return model