How to use schnetpack: 10 common examples

To help you get started, we’ve selected a few schnetpack examples based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

GitHub: atomistic-machine-learning/schnetpack — tests/fixtures/qm9.py (view on GitHub)
def qm9_splits(qm9_dataset, qm9_split):
    """Split the QM9 fixture dataset into subsets.

    The entries of ``qm9_split`` are forwarded positionally, after the
    dataset, to ``spk.data.train_test_split``. Whatever that helper returns
    is handed back unchanged.
    """
    split_args = tuple(qm9_split)
    return spk.data.train_test_split(qm9_dataset, *split_args)
GitHub: atomistic-machine-learning/schnetpack — tests/fixtures/data.py (view on GitHub)
def train_loader(train, batch_size):
    """Wrap the training split in an ``AtomsLoader``.

    Args:
        train: dataset to iterate over.
        batch_size: forwarded as the loader's second positional argument
            (the batch size).

    Returns:
        The ``spk.data.AtomsLoader`` built from ``train`` and ``batch_size``.
    """
    loader = spk.data.AtomsLoader(train, batch_size)
    return loader
GitHub: atomistic-machine-learning/schnetpack — tests/test_data.py (view on GitHub)
def test_loader(example_asedata, batch_size):
    """Check AtomsLoader batching and the energy statistics it reports.

    Every batch must hold ``batch_size`` samples, except possibly the last
    one, which holds whatever remains. Together the batches must cover the
    whole dataset. The reported energy mean and standard deviation are
    compared against 5.0 and 0.0.
    """
    loader = schnetpack.data.AtomsLoader(example_asedata, batch_size)
    n_samples = len(loader.dataset)

    seen = 0
    for batch in loader:
        # The final batch is shorter when n_samples is not a multiple of
        # batch_size, so compare against the number of samples still left
        # instead of a constant per-batch size.
        expected = min(batch_size, n_samples - seen)
        for entry in batch.values():
            assert entry.shape[0] == expected
        seen += expected
    assert seen == n_samples

    mu, std = loader.get_statistics("energy")
    # Statistics are accumulated in floating point, so compare with a
    # tolerance rather than exact tensor equality.
    assert torch.allclose(
        torch.as_tensor(mu["energy"], dtype=torch.float32),
        torch.FloatTensor([5.0]),
    )
    assert torch.allclose(
        torch.as_tensor(std["energy"], dtype=torch.float32),
        torch.FloatTensor([0.0]),
    )
GitHub: atomistic-machine-learning/schnetpack — tests/fixtures/data.py (view on GitHub)
def test_loader(test, batch_size):
    """Wrap the test split in an ``AtomsLoader``.

    Args:
        test: dataset to iterate over.
        batch_size: forwarded as the loader's second positional argument
            (the batch size).

    Returns:
        The ``spk.data.AtomsLoader`` built from ``test`` and ``batch_size``.
    """
    loader = spk.data.AtomsLoader(test, batch_size)
    return loader
GitHub: atomistic-machine-learning/schnetpack — tests/fixtures/qm9.py (view on GitHub)
def qm9_dataset(qm9_dbpath):
    """Return a ``QM9`` dataset backed by the database at ``qm9_dbpath``.

    NOTE(review): what ``QM9`` does when ``qm9_dbpath`` does not exist yet
    (e.g. create or download the database) is decided inside ``QM9`` and is
    not visible here.
    """
    return QM9(qm9_dbpath)
GitHub: atomistic-machine-learning/schnetpack — tests/test_scripts_utils.py (view on GitHub)
lr_decay=0.5,
            lr_min=1e-6,
            logger="csv",
            modelpath=modeldir,
            log_every_n_epochs=2,
            max_steps=30,
            checkpoint_interval=1,
            keep_n_checkpoints=1,
            dataset="qm9",
        )
        trainer = get_trainer(
            args, schnet, qm9_train_loader, qm9_val_loader, metrics=None
        )
        assert trainer._model == schnet
        hook_types = [type(hook) for hook in trainer.hooks]
        assert schnetpack.train.hooks.CSVHook in hook_types
        assert schnetpack.train.hooks.TensorboardHook not in hook_types
        assert schnetpack.train.hooks.MaxEpochHook in hook_types
        assert schnetpack.train.hooks.ReduceLROnPlateauHook in hook_types
GitHub: atomistic-machine-learning/schnetpack — tests/test_scripts_utils.py (view on GitHub)
)
        mean = {args.property: None}
        model = get_model(
            args, train_loader=qm9_train_loader, mean=mean, stddev=mean, atomref=mean
        )

        os.makedirs(modeldir, exist_ok=True)
        evaluate(
            args,
            model,
            qm9_train_loader,
            qm9_val_loader,
            qm9_test_loader,
            "cpu",
            metrics=[
                schnetpack.train.metrics.MeanAbsoluteError(
                    "energy_U0", model_output="energy_U0"
                )
            ],
        )
        assert os.path.exists(os.path.join(modeldir, "evaluation.txt"))
        args.split = ["train"]
        evaluate(
            args,
            model,
            qm9_train_loader,
            qm9_val_loader,
            qm9_test_loader,
            "cpu",
            metrics=[
                schnetpack.train.metrics.MeanAbsoluteError(
                    "energy_U0", model_output="energy_U0"
GitHub: atomistic-machine-learning/schnetpack — tests/test_scripts_utils.py (view on GitHub)
metrics=[
                schnetpack.train.metrics.MeanAbsoluteError(
                    "energy_U0", model_output="energy_U0"
                )
            ],
        )
        args.split = ["validation"]
        evaluate(
            args,
            model,
            qm9_train_loader,
            qm9_val_loader,
            qm9_test_loader,
            "cpu",
            metrics=[
                schnetpack.train.metrics.MeanAbsoluteError(
                    "energy_U0", model_output="energy_U0"
                )
GitHub: atomistic-machine-learning/schnetpack — tests/test_metrics.py (view on GitHub)
def forces_mae():
    """Build an element-wise MAE metric named "forces".

    The metric is constructed from the ``"_forces"`` and ``"dydx"`` keys
    with ``element_wise=True``.
    """
    metric = MeanAbsoluteError(
        "_forces",
        "dydx",
        name="forces",
        element_wise=True,
    )
    return metric
GitHub: atomistic-machine-learning/schnetpack — tests/test_orca_parser.py (view on GitHub)
def test_main_file_parser(main_path, targets_main):
    """Parse an ORCA main output file and compare it with reference targets.

    The parser's combined ``"atoms"`` entry is replaced by two entries: its
    element 0 under ``Properties.Z`` and its element 1 under
    ``Properties.R``. Every key of ``targets_main`` must then be present in
    the parsed results. ``Properties.Z`` must match exactly
    (``np.array_equal``); every other key is compared with ``np.allclose``.
    """
    parser = OrcaMainFileParser(properties=OrcaMainFileParser.properties)
    parser.parse_file(main_path)
    parsed = parser.get_parsed()

    # Split the combined "atoms" entry into its two components.
    atoms = parsed.pop("atoms")
    parsed[Properties.Z], parsed[Properties.R] = atoms[0], atoms[1]

    for prop in targets_main:
        assert prop in parsed
        expected = targets_main[prop]
        # Properties.Z is compared exactly; everything else within the
        # default np.allclose tolerance.
        compare = np.array_equal if prop == Properties.Z else np.allclose
        assert compare(parsed[prop], expected)