How to use the schnetpack.environment.SimpleEnvironmentProvider class in schnetpack

To help you get started, we’ve selected a few schnetpack examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

GitHub: atomistic-machine-learning / schnetpack / src / schnetpack / data / atoms.py (View on GitHub)
def __init__(
        self,
        dbpath,
        subset=None,
        available_properties=None,
        load_only=None,
        units=None,
        environment_provider=SimpleEnvironmentProvider(),
        collect_triples=False,
        center_positions=True,
    ):
        if not dbpath.endswith(".db"):
            raise AtomsDataError(
                "Invalid dbpath! Please make sure to add the file extension '.db' to "
                "your dbpath."
            )

        self.dbpath = dbpath
        self.subset = subset
        self.load_only = load_only
        self.available_properties = self.get_available_properties(available_properties)
        if load_only is None:
            self.load_only = self.available_properties
        if units is None:
GitHub: atomistic-machine-learning / schnetpack / src / schnetpack / datasets / qm9.py (View on GitHub)
def __init__(
        self,
        dbpath,
        download=True,
        subset=None,
        load_only=None,
        collect_triples=False,
        remove_uncharacterized=False,
        environment_provider=spk.environment.SimpleEnvironmentProvider(),
        **kwargs
    ):

        self.remove_uncharacterized = remove_uncharacterized

        available_properties = [
            QM9.A,
            QM9.B,
            QM9.C,
            QM9.mu,
            QM9.alpha,
            QM9.homo,
            QM9.lumo,
            QM9.gap,
            QM9.r2,
            QM9.zpve,
GitHub: atomistic-machine-learning / schnetpack / src / schnetpack / datasets / ani1.py (View on GitHub)
def __init__(
        self,
        dbpath,
        download=True,
        subset=None,
        load_only=None,
        collect_triples=False,
        num_heavy_atoms=8,
        high_energies=False,
        environment_provider=spk.environment.SimpleEnvironmentProvider(),
    ):
        available_properties = [ANI1.energy]
        units = [Hartree]

        self.num_heavy_atoms = num_heavy_atoms
        self.high_energies = high_energies

        super().__init__(
            dbpath=dbpath,
            subset=subset,
            download=download,
            load_only=load_only,
            collect_triples=collect_triples,
            available_properties=available_properties,
            units=units,
            environment_provider=environment_provider,
GitHub: atomistic-machine-learning / schnetpack / src / schnetpack / datasets / iso17.py (View on GitHub)
def __init__(
        self,
        datapath,
        fold,
        download=True,
        load_only=None,
        subset=None,
        collect_triples=False,
        environment_provider=spk.environment.SimpleEnvironmentProvider(),
    ):

        if fold not in self.existing_folds:
            raise ValueError("Fold {:s} does not exist".format(fold))

        available_properties = [ISO17.E, ISO17.F]
        units = [1.0, 1.0]

        self.path = datapath
        self.fold = fold
        dbpath = os.path.join(datapath, "iso17", fold + ".db")

        super().__init__(
            dbpath=dbpath,
            subset=subset,
            load_only=load_only,
GitHub: atomistic-machine-learning / schnetpack / src / schnetpack / interfaces / ase_interface.py (View on GitHub)
def __init__(
        self,
        model,
        device="cpu",
        collect_triples=False,
        environment_provider=SimpleEnvironmentProvider(),
        energy=None,
        forces=None,
        energy_units="eV",
        forces_units="eV/Angstrom",
        **kwargs
    ):
        """Initialize the calculator around ``model``.

        Calls the base ``Calculator.__init__`` with any extra keyword
        arguments (presumably ASE's ``Calculator`` base class — TODO confirm
        the import). Then stores ``model`` on ``self.model`` and builds an
        ``AtomsConverter`` stored on ``self.atoms_converter``.

        Args:
            model: Stored unchanged on ``self.model``. How it is invoked is
                not shown in this excerpt.
            device (str): Forwarded to ``AtomsConverter`` as ``device``.
                Default ``"cpu"``.
            collect_triples (bool): Forwarded to ``AtomsConverter`` as
                ``collect_triples``. Default ``False``.
            environment_provider: Forwarded to ``AtomsConverter`` as
                ``environment_provider``. NOTE: the default
                ``SimpleEnvironmentProvider()`` is evaluated once, when the
                ``def`` executes. Every instance that relies on the default
                therefore shares that single provider object.
            energy: Not referenced in the lines shown here.
            forces: Not referenced in the lines shown here.
            energy_units (str): Not referenced in the lines shown here.
                Default ``"eV"``.
            forces_units (str): Not referenced in the lines shown here.
                Default ``"eV/Angstrom"``.
            **kwargs: Passed through to ``Calculator.__init__``.

        NOTE(review): ``energy``, ``forces``, ``energy_units`` and
        ``forces_units`` are accepted but unused in this body. The excerpt
        is likely truncated, so confirm their handling against the full
        source.
        """
        Calculator.__init__(self, **kwargs)

        self.model = model

        # Converter configured with the neighbor-list provider, the triples
        # flag and the target device. It presumably turns atoms objects into
        # model inputs; AtomsConverter is defined elsewhere, so verify.
        self.atoms_converter = AtomsConverter(
            environment_provider=environment_provider,
            collect_triples=collect_triples,
            device=device,
        )
GitHub: atomistic-machine-learning / schnetpack / src / schnetpack / datasets / matproj.py (View on GitHub)
def __init__(
        self,
        dbpath,
        apikey=None,
        download=True,
        subset=None,
        load_only=None,
        collect_triples=False,
        environment_provider=spk.environment.SimpleEnvironmentProvider(),
    ):

        available_properties = [
            MaterialsProject.EformationPerAtom,
            MaterialsProject.EPerAtom,
            MaterialsProject.BandGap,
            MaterialsProject.TotalMagnetization,
        ]

        units = [eV, eV, eV, 1.0]

        self.apikey = apikey

        super(MaterialsProject, self).__init__(
            dbpath=dbpath,
            subset=subset,
GitHub: atomistic-machine-learning / schnetpack / src / schnetpack / datasets / omdb.py (View on GitHub)
def __init__(
        self,
        path,
        download=True,
        subset=None,
        load_only=None,
        collect_triples=False,
        environment_provider=spk.environment.SimpleEnvironmentProvider(),
    ):
        available_properties = [OrganicMaterialsDatabase.BandGap]

        units = [eV]

        self.path = path

        dbpath = self.path.replace(".tar.gz", ".db")
        self.dbpath = dbpath

        if not os.path.exists(path) and not os.path.exists(dbpath):
            raise FileNotFoundError(
                "Download OMDB dataset (e.g. OMDB-GAP1.tar.gz) from "
                "https://omdb.diracmaterials.org/dataset/ and set datapath to this file"
            )
GitHub: atomistic-machine-learning / schnetpack / src / schnetpack / data / atoms.py (View on GitHub)
def _convert_atoms(
    atoms,
    environment_provider=SimpleEnvironmentProvider(),
    collect_triples=False,
    center_positions=False,
    output=None,
):
    """
        Helper function to convert ASE atoms object to SchNetPack input format.

        Args:
            atoms (ase.Atoms): Atoms object of molecule
            environment_provider (callable): Neighbor list provider.
            device (str): Device for computation (default='cpu')
            output (dict): Destination for converted atoms, if not None

    Returns:
        dict of torch.Tensor: Properties including neighbor lists and masks
            reformated into SchNetPack input format.