How to use the megnet.data.graph.GraphBatchGenerator class in megnet

To help you get started, we’ve selected a few megnet examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

Source: materialsvirtuallab/megnet on GitHub — megnet/models.py
def _create_generator(self, *args, **kwargs) -> Union[GraphBatchDistanceConvert, GraphBatchGenerator]:
        """Build the batch generator that matches this model's graph converter.

        If ``self.graph_converter`` has a ``bond_converter`` attribute, that
        object is passed as the ``distance_converter`` keyword (replacing any
        caller-supplied value) and a ``GraphBatchDistanceConvert`` is built.
        Otherwise a plain ``GraphBatchGenerator`` is built from the same
        arguments.

        Args:
            *args: positional arguments forwarded to the generator constructor.
            **kwargs: keyword arguments forwarded to the generator constructor.

        Returns:
            The newly constructed generator.
        """
        if not hasattr(self.graph_converter, 'bond_converter'):
            return GraphBatchGenerator(*args, **kwargs)
        # The distance-converting generator needs the converter used to
        # process bond distances, so hand over the graph converter's one.
        kwargs['distance_converter'] = self.graph_converter.bond_converter
        return GraphBatchDistanceConvert(*args, **kwargs)
Source: materialsvirtuallab/megnet on GitHub — megnet/data/graph.py
- [ndarray]: List of indices for the start of each bond
                - [ndarray]: List of indices for the end of each bond
        """

        # Get the features and connectivity lists for this batch
        it = itemgetter(*batch_index)
        feature_list_temp = itemgetter_list(self.atom_features, batch_index)
        connection_list_temp = itemgetter_list(self.bond_features, batch_index)
        global_list_temp = itemgetter_list(self.state_features, batch_index)
        index1_temp = itemgetter_list(self.index1_list, batch_index)
        index2_temp = itemgetter_list(self.index2_list, batch_index)

        return feature_list_temp, connection_list_temp, global_list_temp, index1_temp, index2_temp


class GraphBatchDistanceConvert(GraphBatchGenerator):
    """
    Generate batch of structures with bond distance being expanded using a Expansor

    Args:
        atom_features: (list of np.array) list of atom feature matrix,
        bond_features: (list of np.array) list of bond features matrix
        state_features: (list of np.array) list of [1, G] state features, where G is the global state feature dimension
        index1_list: (list of integer) list of (M, ) one side atomic index of the bond, M is different for different
            structures
        index2_list: (list of integer) list of (M, ) the other side atomic index of the bond, M is different for
            different structures, but it has to be the same as the correponding index1.
        targets: (numpy array), N*1, where N is the number of structures
        batch_size: (int) number of samples in a batch
        is_shuffle: (bool) whether to shuffle the structure, default to True
        distance_converter: (bool) converter for processing the distances
Source: materialsvirtuallab/megnet on GitHub — megnet/models.py
def _create_generator(self, *args, **kwargs) -> Union[GraphBatchDistanceConvert, GraphBatchGenerator]:
        """Build the batch generator that matches this model's graph converter.

        If ``self.graph_converter`` has a ``bond_converter`` attribute, that
        object is passed as the ``distance_converter`` keyword (replacing any
        caller-supplied value) and a ``GraphBatchDistanceConvert`` is built.
        Otherwise a plain ``GraphBatchGenerator`` is built from the same
        arguments.

        Args:
            *args: positional arguments forwarded to the generator constructor.
            **kwargs: keyword arguments forwarded to the generator constructor.

        Returns:
            The newly constructed generator.
        """
        if not hasattr(self.graph_converter, 'bond_converter'):
            return GraphBatchGenerator(*args, **kwargs)
        # The distance-converting generator needs the converter used to
        # process bond distances, so hand over the graph converter's one.
        kwargs['distance_converter'] = self.graph_converter.bond_converter
        return GraphBatchDistanceConvert(*args, **kwargs)
Source: materialsvirtuallab/megnet on GitHub — megnet/data/molecule.py
def create_cached_generator(self) -> GraphBatchGenerator:
        """Build graphs for every molecule up front and wrap them in a generator.

        All graphs are generated eagerly from ``self.mols`` and kept in memory,
        then flattened together with ``self.targets`` into the inputs of a
        ``GraphBatchGenerator``.

        Returns:
            (GraphBatchGenerator) Graph generator that relies on having the graphs in memory
        """
        mol_graphs = self._generate_graphs(self.mols)

        # Flatten the per-molecule graphs (plus targets) into the sequence of
        # inputs that is unpacked positionally into GraphBatchGenerator.
        flat_inputs = self.converter.get_flat_data(mol_graphs, self.targets)

        return GraphBatchGenerator(
            *flat_inputs,
            is_shuffle=self.is_shuffle,
            batch_size=self.batch_size,
        )