How to use the chemprop.nn_utils.index_select_ND function in chemprop

To help you get started, we’ve selected a few chemprop examples based on popular ways index_select_ND is used in public projects.

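index_select_ND follows a simple shape contract: given a 2-D source tensor of shape (num_items, hidden) and an integer index tensor of shape (num_targets, max_num_neighbors), it gathers the rows of source named by index and returns a tensor of shape (num_targets, max_num_neighbors, hidden). The sketch below calls it on hand-built tensors and checks the result against plain PyTorch indexing; the values are illustrative, and the comparison assumes the usual gather-and-reshape behaviour.

import torch
from chemprop.nn_utils import index_select_ND

# 5 "message" rows of width 3; chemprop's batched graphs typically reserve row 0 as zero padding.
source = torch.randn(5, 3)
source[0] = 0.0

# For each of 4 targets, the indices of up to 2 neighbours (0 = padding slot).
index = torch.tensor([[1, 2],
                      [3, 0],
                      [4, 1],
                      [0, 0]])

gathered = index_select_ND(source, index)            # shape: (4, 2, 3)

# Equivalent plain-PyTorch gather: flatten the index, select rows, restore the index shape.
reference = source.index_select(0, index.view(-1)).view(*index.shape, source.size(-1))
assert torch.equal(gathered, reference)

# Summing over the neighbour dimension is the usual next step in the excerpts below.
summed = gathered.sum(dim=1)                         # shape: (4, 3)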

wengong-jin / chemprop / chemprop / models / jtnn.py (View on GitHub)
def forward(self,
                fnode: torch.Tensor,
                fmess: torch.Tensor,
                node_graph: torch.Tensor,
                mess_graph: torch.Tensor,
                scope: List[Tuple[int, int]]) -> torch.Tensor:
        messages = torch.zeros(mess_graph.size(0), self.hidden_size)

        if next(self.parameters()).is_cuda:
            fnode, fmess, node_graph, mess_graph, messages = fnode.cuda(), fmess.cuda(), node_graph.cuda(), mess_graph.cuda(), messages.cuda()

        fnode = self.embedding(fnode)
        fmess = index_select_ND(fnode, fmess)
        messages = self.GRU(messages, fmess, mess_graph)

        mess_nei = index_select_ND(messages, node_graph)
        fnode = torch.cat([fnode, mess_nei.sum(dim=1)], dim=-1)
        fnode = self.outputNN(fnode)
        tree_vec = []
        for st, le in scope:
            tree_vec.append(fnode.narrow(0, st, le).mean(dim=0))

        return torch.stack(tree_vec, dim=0)
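
In this tree encoder, index_select_ND appears twice: index_select_ND(fnode, fmess) expands node embeddings into per-message inputs, and index_select_ND(messages, node_graph) collects each node's incoming messages before summing them. The final loop then mean-pools each tree's node vectors using the (start, length) pairs in scope. Below is a small self-contained sketch of that pooling step; pool_by_scope is a hypothetical helper and the shapes are made up.

import torch
from typing import List, Tuple

def pool_by_scope(fnode: torch.Tensor, scope: List[Tuple[int, int]]) -> torch.Tensor:
    # Mean-pool contiguous slices of fnode, one (start, length) slice per tree.
    return torch.stack([fnode.narrow(0, st, le).mean(dim=0) for st, le in scope], dim=0)

fnode = torch.randn(7, 4)                # 7 nodes across 2 trees, hidden size 4
scope = [(0, 3), (3, 4)]                 # tree 0 owns nodes 0-2, tree 1 owns nodes 3-6
tree_vec = pool_by_scope(fnode, scope)   # shape: (2, 4)
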
wengong-jin / chemprop / chemprop / models / mpn.py (View on GitHub)
                attention_scores = [attention_scores[i] * message_attention_mask + (1 - message_attention_mask) * (-1e+20)
                                    for i in range(self.num_heads)]  # num_bonds x maxnb
                attention_weights = [F.softmax(attention_scores[i], dim=1)
                                     for i in range(self.num_heads)]  # num_bonds x maxnb
                message_components = [nei_message * attention_weights[i].unsqueeze(2).repeat((1, 1, self.hidden_size))
                                      for i in range(self.num_heads)]  # num_bonds x maxnb x hidden
                message_components = [component.sum(dim=1) for component in message_components]  # num_bonds x hidden
                message = torch.cat(message_components, dim=1)  # num_bonds x num_heads * hidden
            elif self.atom_messages:
                nei_a_message = index_select_ND(message, a2a)  # num_atoms x max_num_bonds x hidden
                nei_f_bonds = index_select_ND(f_bonds, a2b)  # num_atoms x max_num_bonds x bond_fdim
                nei_message = torch.cat((nei_a_message, nei_f_bonds), dim=2)  # num_atoms x max_num_bonds x hidden + bond_fdim
                message = nei_message.sum(dim=1)  # num_atoms x hidden + bond_fdim
            else:
                # m(a1 -> a2) = [sum_{a0 \in nei(a1)} m(a0 -> a1)] - m(a2 -> a1)
                # message      a_message = sum(nei_a_message)      rev_message
                nei_a_message = index_select_ND(message, a2b)  # num_atoms x max_num_bonds x hidden
                a_message = nei_a_message.sum(dim=1)  # num_atoms x hidden
                rev_message = message[b2revb]  # num_bonds x hidden
                message = a_message[b2a] - rev_message  # num_bonds x hidden

            for lpm in range(self.layers_per_message - 1):
                message = self.W_h[lpm][depth](message)  # num_bonds x hidden
                message = self.act_func(message)
            message = self.W_h[self.layers_per_message - 1][depth](message)

            if self.normalize_messages:
                message = message / message.norm(dim=1, keepdim=True)

            if self.master_node:
                # master_state = self.W_master_in(self.act_func(nei_message.sum(dim=0))) #try something like this to preserve invariance for master node
                # master_state = self.GRU_master(nei_message.unsqueeze(1))
                # master_state = master_state[-1].squeeze(0) #this actually doesn't preserve order invariance anymore
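
The message_attention branch earlier in this excerpt applies a masked softmax over the padded neighbour dimension: real neighbour slots keep their scores, while padding slots are pushed to a huge negative value so their weights vanish. A single-head sketch of that pattern follows; it drops the per-head W_ma projection and uses illustrative names and shapes, so it is not chemprop's exact code.

import torch
import torch.nn.functional as F

num_bonds, max_nb, hidden = 5, 4, 8

nei_message = torch.randn(num_bonds, max_nb, hidden)         # gathered neighbour messages
message = torch.randn(num_bonds, hidden)                     # current message for each bond
mask = (torch.rand(num_bonds, max_nb) > 0.3).float()         # 1 = real neighbour, 0 = padding

scores = (nei_message * message.unsqueeze(1)).sum(dim=2)     # num_bonds x max_nb
scores = scores * mask + (1 - mask) * (-1e+20)               # padding slots get a huge negative score
weights = F.softmax(scores, dim=1)                           # num_bonds x max_nb
attended = (nei_message * weights.unsqueeze(2)).sum(dim=1)   # num_bonds x hidden
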
wengong-jin / chemprop / chemprop / models / mpn.py (View on GitHub)
# TODO: Parallelize attention heads
                nei_message = index_select_ND(message, b2b)
                message = message.unsqueeze(1).repeat((1, nei_message.size(1), 1))  # num_bonds x maxnb x hidden
                attention_scores = [(self.W_ma[i](nei_message) * message).sum(dim=2)
                                    for i in range(self.num_heads)]  # num_bonds x maxnb
                attention_scores = [attention_scores[i] * message_attention_mask + (1 - message_attention_mask) * (-1e+20)
                                    for i in range(self.num_heads)]  # num_bonds x maxnb
                attention_weights = [F.softmax(attention_scores[i], dim=1)
                                     for i in range(self.num_heads)]  # num_bonds x maxnb
                message_components = [nei_message * attention_weights[i].unsqueeze(2).repeat((1, 1, self.hidden_size))
                                      for i in range(self.num_heads)]  # num_bonds x maxnb x hidden
                message_components = [component.sum(dim=1) for component in message_components]  # num_bonds x hidden
                message = torch.cat(message_components, dim=1)  # num_bonds x num_heads * hidden
            elif self.atom_messages:
                nei_a_message = index_select_ND(message, a2a)  # num_atoms x max_num_bonds x hidden
                nei_f_bonds = index_select_ND(f_bonds, a2b)  # num_atoms x max_num_bonds x bond_fdim
                nei_message = torch.cat((nei_a_message, nei_f_bonds), dim=2)  # num_atoms x max_num_bonds x hidden + bond_fdim
                message = nei_message.sum(dim=1)  # num_atoms x hidden + bond_fdim
            else:
                # m(a1 -> a2) = [sum_{a0 \in nei(a1)} m(a0 -> a1)] - m(a2 -> a1)
                # message      a_message = sum(nei_a_message)      rev_message
                nei_a_message = index_select_ND(message, a2b)  # num_atoms x max_num_bonds x hidden
                a_message = nei_a_message.sum(dim=1)  # num_atoms x hidden
                rev_message = message[b2revb]  # num_bonds x hidden
                message = a_message[b2a] - rev_message  # num_bonds x hidden

            for lpm in range(self.layers_per_message - 1):
                message = self.W_h[lpm][depth](message)  # num_bonds x hidden
                message = self.act_func(message)
            message = self.W_h[self.layers_per_message - 1][depth](message)

            if self.normalize_messages:
                message = message / message.norm(dim=1, keepdim=True)
wengong-jin / chemprop / chemprop / models / mpn.py (View on GitHub)
global_attention_mask[i, start:start + length] = 1

            if next(self.parameters()).is_cuda:
                global_attention_mask = global_attention_mask.cuda()

        # Message passing
        for depth in range(self.depth - 1):
            if self.undirected:
                message = (message + message[b2revb]) / 2

            if self.learn_virtual_edges:
                message = message * straight_through_mask

            if self.message_attention:
                # TODO: Parallelize attention heads
                nei_message = index_select_ND(message, b2b)
                message = message.unsqueeze(1).repeat((1, nei_message.size(1), 1))  # num_bonds x maxnb x hidden
                attention_scores = [(self.W_ma[i](nei_message) * message).sum(dim=2)
                                    for i in range(self.num_heads)]  # num_bonds x maxnb
                attention_scores = [attention_scores[i] * message_attention_mask + (1 - message_attention_mask) * (-1e+20)
                                    for i in range(self.num_heads)]  # num_bonds x maxnb
                attention_weights = [F.softmax(attention_scores[i], dim=1)
                                     for i in range(self.num_heads)]  # num_bonds x maxnb
                message_components = [nei_message * attention_weights[i].unsqueeze(2).repeat((1, 1, self.hidden_size))
                                      for i in range(self.num_heads)]  # num_bonds x maxnb x hidden
                message_components = [component.sum(dim=1) for component in message_components]  # num_bonds x hidden
                message = torch.cat(message_components, dim=1)  # num_bonds x num_heads * hidden
            elif self.atom_messages:
                nei_a_message = index_select_ND(message, a2a)  # num_atoms x max_num_bonds x hidden
                nei_f_bonds = index_select_ND(f_bonds, a2b)  # num_atoms x max_num_bonds x bond_fdim
                nei_message = torch.cat((nei_a_message, nei_f_bonds), dim=2)  # num_atoms x max_num_bonds x hidden + bond_fdim
                message = nei_message.sum(dim=1)  # num_atoms x hidden + bond_fdim
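
The else branch in the earlier mpn.py excerpts implements the directed-edge update m(a1 -> a2) = [sum over a0 in nei(a1)] m(a0 -> a1) - m(a2 -> a1): sum every message arriving at the source atom, then subtract the reverse message so a bond never echoes itself. The sketch below replays that update on a hand-built three-atom path; the index tensors follow the a2b / b2a / b2revb naming from the excerpts, but their values are toy data rather than chemprop's real batching.

import torch
from chemprop.nn_utils import index_select_ND

hidden = 4

# Path graph a1 - a2 - a3. Directed messages: 1 = a1->a2, 2 = a2->a1, 3 = a2->a3, 4 = a3->a2.
# Index 0 is the usual zero padding slot for both atoms and bonds.
message = torch.randn(5, hidden)
message[0] = 0.0

a2b = torch.tensor([[0, 0],    # padding atom
                    [2, 0],    # messages arriving at a1: a2->a1
                    [1, 4],    # messages arriving at a2: a1->a2 and a3->a2
                    [3, 0]])   # messages arriving at a3: a2->a3
b2a = torch.tensor([0, 1, 2, 2, 3])      # atom each directed message starts from
b2revb = torch.tensor([0, 2, 1, 4, 3])   # index of the reverse message

nei_a_message = index_select_ND(message, a2b)   # num_atoms x max_num_bonds x hidden
a_message = nei_a_message.sum(dim=1)            # num_atoms x hidden
rev_message = message[b2revb]                   # num_bonds x hidden
new_message = a_message[b2a] - rev_message      # num_bonds x hidden

# e.g. the new a2->a3 message (row 3) equals the incoming a1->a2 message alone.
assert torch.allclose(new_message[3], message[1])
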
wengong-jin / chemprop / chemprop / models / jtnn.py (View on GitHub)
def forward(self, smiles_batch: List[str]):
        # Get MolTrees with memoization
        mol_batch = [SMILES_TO_MOLTREE[smiles]
                     if smiles in SMILES_TO_MOLTREE else SMILES_TO_MOLTREE.setdefault(smiles, MolTree(smiles))
                     for smiles in smiles_batch]
        fnode, fmess, node_graph, mess_graph, scope = self.tensorize(mol_batch)

        if next(self.parameters()).is_cuda:
            fnode, fmess, node_graph, mess_graph = fnode.cuda(), fmess.cuda(), node_graph.cuda(), mess_graph.cuda()

        fnode = self.embedding(fnode)
        fmess = index_select_ND(fnode, fmess)
        tree_vec = self.jtnn((fnode, fmess, node_graph, mess_graph, scope, []))
        mol_vec = self.mpn(smiles_batch)

        return torch.cat([tree_vec, mol_vec], dim=-1)
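
The first statement memoizes MolTree construction in a module-level dictionary so a repeated SMILES string is only parsed once per process. Below is a generic sketch of the same caching pattern; parse_once and build_tree are hypothetical stand-ins, not chemprop APIs.

from typing import Dict

_CACHE: Dict[str, str] = {}

def build_tree(smiles: str) -> str:
    # Stand-in for the expensive MolTree(smiles) construction in the excerpt above.
    return smiles.upper()

def parse_once(smiles: str) -> str:
    # Build the tree on the first request, then reuse the cached object.
    if smiles not in _CACHE:
        _CACHE[smiles] = build_tree(smiles)
    return _CACHE[smiles]

trees = [parse_once(s) for s in ["CCO", "c1ccccc1", "CCO"]]   # "CCO" is parsed only once
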
wengong-jin / chemprop / chemprop / models / mpn.py (View on GitHub)
if self.master_node and self.use_master_as_output:
            assert self.hidden_size == self.master_dim
            mol_vecs = []
            for start, size in b_scope:
                if size == 0:
                    mol_vecs.append(self.cached_zero_vector)
                else:
                    mol_vecs.append(master_state[start])
            return torch.stack(mol_vecs, dim=0)

        # Get atom hidden states from message hidden states
        if self.learn_virtual_edges:
            message = message * straight_through_mask

        a2x = a2a if self.atom_messages else a2b
        nei_a_message = index_select_ND(message, a2x)  # num_atoms x max_num_bonds x hidden
        a_message = nei_a_message.sum(dim=1)  # num_atoms x hidden
        a_input = torch.cat([f_atoms, a_message], dim=1)  # num_atoms x (atom_fdim + hidden)
        atom_hiddens = self.act_func(self.W_o(a_input))  # num_atoms x hidden
        atom_hiddens = self.dropout_layer(atom_hiddens)  # num_atoms x hidden

        if self.deepset:
            atom_hiddens = self.W_s2s_a(atom_hiddens)
            atom_hiddens = self.act_func(atom_hiddens)
            atom_hiddens = self.W_s2s_b(atom_hiddens)

        if self.bert_pretraining:
            atom_preds = self.W_v(atom_hiddens)[1:]  # num_atoms x vocab/output size (leave out atom padding)

        # Readout
        if self.set2set:
            # Set up sizes
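
After the message passing loop, the same gather-and-sum pattern produces per-atom hidden states: index_select_ND pulls each atom's incoming messages, they are summed, concatenated with the raw atom features, and passed through the output layer W_o. A self-contained sketch of that step follows; the sizes and random tensors are illustrative, and torch.relu stands in for whatever activation is configured.

import torch
import torch.nn as nn
from chemprop.nn_utils import index_select_ND

num_atoms, max_num_bonds, hidden, atom_fdim = 4, 3, 8, 5

message = torch.randn(10, hidden)                         # one hidden vector per directed bond
message[0] = 0.0                                          # row 0 is the zero padding slot
f_atoms = torch.randn(num_atoms, atom_fdim)               # raw atom features
a2b = torch.randint(0, 10, (num_atoms, max_num_bonds))    # incoming bond indices per atom (0-padded)

W_o = nn.Linear(atom_fdim + hidden, hidden)

nei_a_message = index_select_ND(message, a2b)      # num_atoms x max_num_bonds x hidden
a_message = nei_a_message.sum(dim=1)               # num_atoms x hidden
a_input = torch.cat([f_atoms, a_message], dim=1)   # num_atoms x (atom_fdim + hidden)
atom_hiddens = torch.relu(W_o(a_input))            # num_atoms x hidden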