How to use the chemprop.utils.load_checkpoint function in chemprop

To help you get started, we’ve selected a few chemprop examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

Source (GitHub): wengong-jin/chemprop — chemprop/train/make_predictions.py
test_smiles = test_data.smiles()

    if args.compound_names:
        compound_names = test_data.compound_names()
    print(f'Test size = {len(test_data):,}')

    # Normalize features
    if train_args.features_scaling:
        test_data.normalize_features(features_scaler)

    # Predict with each model individually and sum predictions
    sum_preds = np.zeros((len(test_data), args.num_tasks))
    print(f'Predicting with an ensemble of {len(args.checkpoint_paths)} models')
    for checkpoint_path in tqdm(args.checkpoint_paths, total=len(args.checkpoint_paths)):
        # Load model
        model = load_checkpoint(checkpoint_path, cuda=args.cuda)
        model_preds = predict(
            model=model,
            data=test_data,
            args=args,
            scaler=scaler
        )
        sum_preds += np.array(model_preds)

    # Ensemble predictions
    avg_preds = sum_preds / args.ensemble_size
    avg_preds = avg_preds.tolist()

    # Save predictions
    assert len(test_data) == len(avg_preds)
    print(f'Saving predictions to {args.preds_path}')
Source (GitHub): wengong-jin/chemprop — scripts/viz_attention.py
def visualize_attention(args: Namespace):
    """Render attention weights for every molecule in a dataset, one batch at a time.

    Reads the following attributes from ``args``: ``data_path`` (passed to
    ``get_data``), ``checkpoint_path`` and ``cuda`` (passed to
    ``load_checkpoint``), ``batch_size`` (number of SMILES handed to each
    visualization call) and ``viz_dir`` (forwarded to ``viz_attention``).

    Progress messages are printed to stdout; nothing is returned.
    """
    print('Loading data')
    dataset = get_data(args.data_path)
    all_smiles = dataset.smiles()
    num_smiles = len(all_smiles)
    print('Data size = {:,}'.format(num_smiles))

    print('Loading model from "{}"'.format(args.checkpoint_path))
    # This version of load_checkpoint yields a 4-tuple; only the model is used here.
    checkpoint = load_checkpoint(args.checkpoint_path, cuda=args.cuda)
    model, _, _, _ = checkpoint
    # The first element of the model is the component that exposes viz_attention
    # (presumably the message passing network, going by the original name "mpn").
    encoder = model[0]

    step = args.batch_size
    for start in trange(0, num_smiles, step):
        encoder.viz_attention(all_smiles[start:start + step], viz_dir=args.viz_dir)
Source (GitHub): wengong-jin/chemprop — chemprop/train/run_training.py
debug(f'Validation {task_name} {args.metric} = {val_score:.6f}')
                        writer.add_scalar(f'validation_{task_name}_{args.metric}', val_score, n_iter)

            # Save model checkpoint if improved validation score, or always save it if unsupervised
            if args.minimize_score and avg_val_score < best_score or \
                    not args.minimize_score and avg_val_score > best_score or \
                    args.dataset_type == 'unsupervised':
                best_score, best_epoch = avg_val_score, epoch
                save_checkpoint(os.path.join(save_dir, 'model.pt'), model, scaler, features_scaler, args)

        if args.dataset_type == 'unsupervised':
            return [0]  # rest of this is meaningless when unsupervised            

        # Evaluate on test set using model with best validation score
        info(f'Model {model_idx} best validation {args.metric} = {best_score:.6f} on epoch {best_epoch}')
        model = load_checkpoint(os.path.join(save_dir, 'model.pt'), cuda=args.cuda, logger=logger)

        if args.split_test_by_overlap_dataset is not None:
            overlap_data = get_data(path=args.split_test_by_overlap_dataset, logger=logger)
            overlap_smiles = set(overlap_data.smiles())
            test_data_intersect, test_data_nonintersect = [], []
            for d in test_data.data:
                if d.smiles in overlap_smiles:
                    test_data_intersect.append(d)
                else:
                    test_data_nonintersect.append(d)
            test_data_intersect, test_data_nonintersect = MoleculeDataset(test_data_intersect), MoleculeDataset(test_data_nonintersect)
            for name, td in [('Intersect', test_data_intersect), ('Nonintersect', test_data_nonintersect)]:
                test_preds = predict(
                    model=model,
                    data=td,
                    args=args,
Source (GitHub): wengong-jin/chemprop — scripts/viz_attention.py
def visualize_attention(args: Namespace):
    """Render attention weights for every molecule in a dataset, one batch at a time.

    Reads the following attributes from ``args``: ``data_path`` (passed to
    ``get_data``), ``checkpoint_path`` and ``cuda`` (passed to
    ``load_checkpoint``), ``batch_size`` (number of SMILES handed to each
    visualization call) and ``viz_dir`` (forwarded to ``viz_attention``).

    Progress messages are printed to stdout; nothing is returned.
    """
    print('Loading data')
    dataset = get_data(path=args.data_path)
    all_smiles = dataset.smiles()
    num_smiles = len(all_smiles)
    print(f'Data size = {num_smiles:,}')

    print(f'Loading model from "{args.checkpoint_path}"')
    model = load_checkpoint(args.checkpoint_path, cuda=args.cuda)
    # The first element of the model is the component that exposes viz_attention
    # (presumably the message passing network, going by the original name "mpn").
    encoder = model[0]

    step = args.batch_size
    for start in trange(0, num_smiles, step):
        encoder.viz_attention(all_smiles[start:start + step], viz_dir=args.viz_dir)
Source (GitHub): wengong-jin/chemprop — chemprop/train/run_training.py
if args.dataset_type == 'bert_pretraining':
        # Only predict targets that are masked out
        test_targets['vocab'] = [target if mask == 0 else None for target, mask in zip(test_targets['vocab'], test_data.mask())]

    # Train ensemble of models
    for model_idx in range(args.ensemble_size):
        # Tensorboard writer
        save_dir = os.path.join(args.save_dir, f'model_{model_idx}')
        os.makedirs(save_dir, exist_ok=True)
        writer = SummaryWriter(log_dir=save_dir)

        # Load/build model
        if args.checkpoint_paths is not None:
            debug(f'Loading model {model_idx} from {args.checkpoint_paths[model_idx]}')
            model = load_checkpoint(args.checkpoint_paths[model_idx], current_args=args, logger=logger)
        else:
            debug(f'Building model {model_idx}')
            model = build_model(args)

        debug(model)
        debug(f'Number of parameters = {param_count(model):,}')
        if args.cuda:
            debug('Moving model to cuda')
            model = model.cuda()

        # Ensure that model is saved in correct location for evaluation if 0 epochs
        save_checkpoint(os.path.join(save_dir, 'model.pt'), model, scaler, features_scaler, args)

        if args.adjust_weight_decay:
            args.pnorm_target = compute_pnorm(model)
Source (GitHub): wengong-jin/chemprop — scripts/visualize_encoding_property_space.py
def visualize_encoding_property_space(args: Namespace):
    # Load data
    data = get_data(path=args.data_path)

    # Sort according to similarity measure
    if args.similarity_measure == 'property':
        data.sort(key=lambda d: d.targets[args.task_index])
    elif args.similarity_measure == 'random':
        data.shuffle(args.seed)
    else:
        raise ValueError(f'similarity_measure "{args.similarity_measure}" not supported or not implemented yet.')

    # Load model and scalers
    model = load_checkpoint(args.checkpoint_path)
    scaler, features_scaler = load_scalers(args.checkpoint_path)
    data.normalize_features(features_scaler)

    # Random seed
    if args.seed is not None:
        random.seed(args.seed)

    # Generate visualizations
    for i in trange(args.num_examples):
        # Get random three molecules with similar properties
        index = random.randint(1, len(data) - 2)
        molecules = MoleculeDataset(data[index - 1:index + 2])
        molecule_targets = [t[args.task_index] for t in molecules.targets()]

        # Encode three molecules
        molecule_encodings = model.encoder(molecules.smiles())