How to use the chemprop.train.cross_validate function in chemprop

To help you get started, we’ve selected a few chemprop examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github wengong-jin / chemprop / model_comparison.py View on Github external
# Configure per-dataset training arguments for one model-comparison run.
# (Excerpt starts mid-function: the enclosing loop/def is not shown, which
# is why the indentation jumps after this first line.)
args.dataset_type = dataset_type
        args.save_dir = os.path.join(args.save_dir, dataset_name)
        args.num_folds = num_folds
        args.metric = metric
        if features_dir is not None:
            # Pre-computed molecular features are expected at
            # <features_dir>/<dataset_name>.pckl — TODO confirm file format.
            args.features_path = [os.path.join(features_dir, dataset_name + '.pckl')]
        modify_train_args(args)  # chemprop helper that finalizes/validates args

        # Set up logging for training
        os.makedirs(args.save_dir, exist_ok=True)
        fh = logging.FileHandler(os.path.join(args.save_dir, args.log_name))
        fh.setLevel(logging.DEBUG)

        # Cross validate
        # Attach the per-dataset file handler only for the duration of this
        # run, so each dataset's log file contains only its own output.
        TRAIN_LOGGER.addHandler(fh)
        mean_score, std_score = cross_validate(args, TRAIN_LOGGER)
        TRAIN_LOGGER.removeHandler(fh)

        # Record results
        logger.info(f'{mean_score} +/- {std_score} {metric}')
        # Build a throwaway model only to report its parameter count.
        temp_model = build_model(args)
        logger.info(f'num params: {param_count(temp_model):,}')
github wengong-jin / chemprop / end_to_end.py View on Github external
# Run hyperparameter search, then retrain with the best config and predict.
# (Excerpt starts mid-function: the enclosing def is not shown, which is why
# the indentation jumps after this first line.)
optimize_hyperparameters(args)

    # Determine best hyperparameters, update args, and train
    results = load_sorted_results(args.results_dir)
    config = results[0]  # results are sorted, so index 0 is the best config
    config.pop('loss')   # drop the score entry; remaining keys are hyperparams
    print('Best config')
    pprint(config)
    for key, value in config.items():
        setattr(args, key, value)

    # Retrain on the saved train+val data; the held-out test set is used
    # separately for prediction below.
    args.data_path = args.train_val_save
    args.separate_test_set = None
    args.split_sizes = [0.8, 0.2, 0.0]  # no need for a test set during training

    cross_validate(args, logger)

    # Predict on test data
    args.checkpoint_dir = args.save_dir
    update_args_from_checkpoint_dir(args)
    args.compound_names = True  # only if test set has compound names
    args.ensemble_size = 5  # might want to make this an arg somehow (w/o affecting hyperparameter optimization)

    make_predictions(args)
github wengong-jin / chemprop / hyperparameter_optimization.py View on Github external
# Objective-function body for hyperparameter optimization (excerpt starts
# mid-function; `hyperparams`, `args`, `results`, `logger`, `train_logger`
# are defined in the unshown enclosing scope). Casts integer-valued
# hyperparameters back from floats, cross-validates, and records the score.
for key in INT_KEYS:
            hyperparams[key] = int(hyperparams[key])

        # Update args with hyperparams
        hyper_args = deepcopy(args)
        if args.save_dir is not None:
            # NOTE(review): the conditional expression below is redundant —
            # both branches are the identical f-string, so `key in INT_KEYS`
            # has no effect. Likely a leftover from formatting ints specially.
            folder_name = '_'.join([f'{key}_{value}' if key in INT_KEYS else f'{key}_{value}' for key, value in hyperparams.items()])
            hyper_args.save_dir = os.path.join(hyper_args.save_dir, folder_name)
        for key, value in hyperparams.items():
            setattr(hyper_args, key, value)

        # Record hyperparameters
        logger.info(hyperparams)

        # Cross validate
        mean_score, std_score = cross_validate(hyper_args, train_logger)

        # Record results
        # Build a throwaway model only to report its parameter count.
        temp_model = build_model(hyper_args)
        num_params = param_count(temp_model)
        logger.info(f'num params: {num_params:,}')
        logger.info(f'{mean_score} +/- {std_score} {hyper_args.metric}')

        results.append({
            'mean_score': mean_score,
            'std_score': std_score,
            'hyperparams': hyperparams,
            'num_params': num_params
        })

        # Deal with nan
        # (Excerpt is cut off here — presumably a nan score is replaced with a
        # penalty value so the optimizer avoids this configuration; verify
        # against the full source.)
        if np.isnan(mean_score):
github wengong-jin / chemprop / train.py View on Github external
import logging

from chemprop.parsing import parse_train_args
from chemprop.train import cross_validate
from chemprop.utils import set_logger


# Module-level training logger. It records everything down to DEBUG and is
# isolated from the root logger so that the handlers installed by
# `set_logger` are the only destinations for training output.
logger = logging.getLogger('train')
logger.propagate = False  # do not forward records to the root logger
logger.setLevel(logging.DEBUG)

if __name__ == '__main__':
    # Parse CLI arguments, route log output according to them, then run
    # k-fold cross-validated training.
    args = parse_train_args()
    set_logger(logger, args.save_dir, args.quiet)
    cross_validate(args, logger)