How to use the dgl.function.copy_edge function in dgl

To help you get started, we’ve selected a few dgl examples, based on popular ways dgl.function.copy_edge is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github dmlc / dgl / examples / pytorch / line_graph / gnn.py View on Github external
def forward(self, g, lg, x, y, deg_g, deg_lg, pm_pd):
        """Update the two feature sets ``x`` and ``y`` and return them as ``(x, y)``.

        ``x`` is aggregated over ``g`` and ``y`` over ``lg``; ``y`` is also
        installed on ``g`` as the edge feature ``'y'`` so that it can be summed
        into the nodes of ``g``.  ``deg_g`` / ``deg_lg`` scale ``x`` / ``y``
        before ``theta_deg`` / ``gamma_deg``, and ``pm_pd`` indexes rows of the
        incoming ``x`` for the ``y`` update.

        NOTE(review): ``lg`` is presumably the line graph of ``g`` (the file
        lives under ``line_graph/``) -- confirm with the caller.
        """
        # Row lookup into the *incoming* x; it is consumed only by the y update.
        pmpd_x = F.embedding(pm_pd, x)

        x_terms = [
            theta(z)
            for theta, z in zip(self.theta_list, self.aggregate(g, x))
        ]
        sum_x = sum(x_terms)

        # Copy each edge's 'y' into a message and sum the messages per node.
        g.set_e_repr({'y' : y})
        g.update_all(fn.copy_edge(edge='y', out='m'), fn.sum('m', 'pmpd_y'))
        pmpd_y = g.pop_n_repr('pmpd_y')

        x = (self.theta_x(x)
             + self.theta_deg(deg_g * x)
             + sum_x
             + self.theta_y(pmpd_y))
        half = self.out_feats // 2
        # First `half` columns pass through unchanged; the rest go through ReLU.
        x = self.bn_x(th.cat([x[:, :half], F.relu(x[:, half:])], 1))

        y_terms = [
            gamma(z)
            for gamma, z in zip(self.gamma_list, self.aggregate(lg, y))
        ]
        sum_y = sum(y_terms)

        y = (self.gamma_y(y)
             + self.gamma_deg(deg_lg * y)
             + sum_y
             + self.gamma_x(pmpd_x))
        y = self.bn_y(th.cat([y[:, :half], F.relu(y[:, half:])], 1))

        return x, y
github dmlc / dgl / tests / compute / test_kernel.py View on Github external
def _test(red, partial):
        """Run the builtin ``copy_edge`` message function, then start a UDF pass.

        ``red`` is the key used to pick a reducer from the ``builtin`` and
        ``udf_reduce`` tables.  ``partial`` selects ``g.pull`` on a subset of
        nodes instead of ``g.update_all``.

        NOTE(review): this excerpt stops partway through the second (UDF)
        pass, so whatever compares the two passes is not visible here.
        """
        g = dgl.DGLGraph(nx.erdos_renyi_graph(100, 0.1))
        # NOTE(zihao): add self-loop to avoid zero-degree nodes.
        g.add_edges(g.nodes(), g.nodes())
        hu, hv, he = generate_feature(g, 'none', 'none')
        if partial:
            # Even node ids 0, 2, ..., 98: the subset passed to g.pull below.
            nid = F.tensor(list(range(0, 100, 2)))

        # Install clones of the generated features, passed through F.attach_grad.
        g.ndata['u'] = F.attach_grad(F.clone(hu))
        g.ndata['v'] = F.attach_grad(F.clone(hv))
        g.edata['e'] = F.attach_grad(F.clone(he))

        # First pass: builtin copy_edge + builtin reducer, result stored in 'r1'.
        with F.record_grad():
            if partial:
                g.pull(nid, fn.copy_edge(edge='e', out='m'),
                       builtin[red](msg='m', out='r1'))
            else:
                g.update_all(fn.copy_edge(edge='e', out='m'),
                             builtin[red](msg='m', out='r1'))
            r1 = g.ndata['r1']
            F.backward(F.reduce_sum(r1))
            # Gradient of the edge feature 'e' after the builtin pass.
            e_grad1 = F.grad(g.edata['e'])

        # reset grad
        g.ndata['u'] = F.attach_grad(F.clone(hu))
        g.ndata['v'] = F.attach_grad(F.clone(hv))
        g.edata['e'] = F.attach_grad(F.clone(he))

        # Second pass: the same computation using user-defined functions.
        with F.record_grad():
            if partial:
                g.pull(nid, udf_copy_edge, udf_reduce[red])
github dmlc / dgl / tests / compute / test_function.py View on Github external
def test_copy_edge():
    """Check ``fn.copy_edge`` through ``update_all`` and through ``send``/``recv``.

    Both ways of triggering message passing must leave the same values in
    the node field ``'out'``.
    """
    expected = [10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]

    # copy_edge with both fields
    g = generate_graph()
    g.register_message_func(fn.copy_edge(edge='h', out='m'))
    g.register_reduce_func(reducer_both)

    def run_update_all():
        g.update_all()

    def run_send_then_recv():
        g.send()
        g.recv()

    # Each trigger is followed by the same check; 'out' is popped every time.
    for trigger in (run_update_all, run_send_then_recv):
        trigger()
        assert F.allclose(g.ndata.pop('out'), F.tensor(expected))
github dmlc / dgl / examples / pytorch / jtnn / jtnn / mpn.py View on Github external
class LoopyBPUpdate(nn.Module):
    """Compute ``'msg'`` as ``relu(msg_input + W_h(accum_msg))``.

    ``forward`` reads the fields ``'msg_input'`` and ``'accum_msg'`` from
    ``nodes.data`` and returns a dict with the single key ``'msg'``.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size

        # Square, bias-free projection applied to the accumulated messages.
        self.W_h = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, nodes):
        fields = nodes.data
        projected = self.W_h(fields['accum_msg'])
        return {'msg': F.relu(fields['msg_input'] + projected)}


# Built-in message function: forwards the edge feature 'msg' as the message 'msg'.
mpn_gather_msg = DGLF.copy_edge(edge='msg', out='msg')
# Built-in reducer: sums the 'msg' messages and stores the result under 'm'.
mpn_gather_reduce = DGLF.sum(msg='msg', out='m')


class GatherUpdate(nn.Module):
    """Compute ``'h'`` as ``relu(W_o([x, m]))``.

    ``forward`` concatenates the fields ``'x'`` and ``'m'`` of ``nodes.data``
    along dimension 1 and returns a dict with the single key ``'h'``.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size

        # The input width is ATOM_FDIM plus hidden_size, one part per
        # concatenated field.
        self.W_o = nn.Linear(ATOM_FDIM + hidden_size, hidden_size)

    def forward(self, nodes):
        fields = nodes.data
        gathered = fields['m']
        joined = torch.cat([fields['x'], gathered], 1)
        return {'h': F.relu(self.W_o(joined))}
github dmlc / dgl / python / dgl / model_zoo / chem / jtnn / jtnn_dec.py View on Github external
from .nnutils import GRUUpdate, cuda

# NOTE(review): presumably the maximum number of neighbours handled per node --
# confirm at the use sites.
MAX_NB = 8
# NOTE(review): presumably an upper bound on the number of decoding steps --
# confirm at the use sites.
MAX_DECODE_LEN = 100


def dfs_order(forest, roots):
    """Yield ``(edge_id ^ label, label)`` pairs from a labelled DFS of ``forest``.

    The traversal is ``dfs_labeled_edges_generator(forest, roots,
    has_reverse_edge=True)``; its outputs are walked in lockstep.
    """
    labeled = dfs_labeled_edges_generator(forest, roots, has_reverse_edge=True)
    for edge_ids, labels in zip(*labeled):
        # Shortcut that only holds for molecule trees: a reverse edge's ID is
        # the forward edge's ID XOR 1.  The general way to locate reverse
        # edges would be find_edges().
        yield edge_ids ^ labels, labels


# Built-in message function: forwards the edge feature 'm' as the message 'm'.
dec_tree_node_msg = DGLF.copy_edge(edge='m', out='m')
# Built-in reducer: sums the 'm' messages and stores the result under 'h'.
dec_tree_node_reduce = DGLF.sum(msg='m', out='h')


def dec_tree_node_update(nodes):
    """Return ``{'new': zeros}`` shaped like the existing ``'new'`` field.

    The zeros are written into a clone, so ``nodes.data['new']`` itself is
    left unmodified.
    """
    cleared = nodes.data['new'].clone()
    cleared.zero_()
    return {'new': cleared}


# Pair of built-in message functions: copy the source fields 'm' and 'rm'
# into messages of the same names.
dec_tree_edge_msg = [DGLF.copy_src(
    src='m', out='m'), DGLF.copy_src(src='rm', out='rm')]
# Matching pair of built-in reducers: sum 'm' into 's' and 'rm' into 'accum_rm'.
dec_tree_edge_reduce = [
    DGLF.sum(msg='m', out='s'), DGLF.sum(msg='rm', out='accum_rm')]


def have_slots(fa_slots, ch_slots):
    if len(fa_slots) > 2 and len(ch_slots) > 2:
        return True
github dmlc / dgl / examples / pytorch / jtnn / jtnn / jtmpn.py View on Github external
self.W_h = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, node):
        """Return ``{'msg': relu(msg_input + W_h(accum_msg + alpha))}``.

        All three inputs are read from ``node.data``.
        """
        fields = node.data
        combined = fields['accum_msg'] + fields['alpha']
        return {'msg': torch.relu(fields['msg_input'] + self.W_h(combined))}


# With PAPER set, 'alpha' is gathered alongside 'msg' (two message functions
# paired with two reducers); otherwise only 'msg' is copied and summed.
if PAPER:
    mpn_gather_msg = [
        DGLF.copy_edge(edge='msg', out='msg'),
        DGLF.copy_edge(edge='alpha', out='alpha')
    ]
    mpn_gather_reduce = [
        DGLF.sum(msg='msg', out='m'),
        DGLF.sum(msg='alpha', out='accum_alpha'),
    ]
else:
    mpn_gather_msg = DGLF.copy_edge(edge='msg', out='msg')
    mpn_gather_reduce = DGLF.sum(msg='msg', out='m')


class GatherUpdate(nn.Module):
    def __init__(self, hidden_size):
        super(GatherUpdate, self).__init__()
        self.hidden_size = hidden_size
github dmlc / dgl / examples / pytorch / transformer / modules / models.py View on Github external
def propagate_attention(self, g, eids):
        """Score the edges ``eids`` of ``g`` and push weighted values along them.

        Writes the edge field ``'score'``, then reduces two message streams:
        ``'v'`` into ``'wv'`` and ``'score'`` into ``'z'``.

        NOTE(review): ``src_dot_dst`` and ``scaled_exp`` are defined elsewhere;
        presumably a source/destination dot product followed by an
        exponential scaled by ``sqrt(d_k)`` -- confirm against their
        definitions.
        """
        # Compute attention score
        g.apply_edges(src_dot_dst('k', 'q', 'score'), eids)
        scale = np.sqrt(self.d_k)
        g.apply_edges(scaled_exp('score', scale), eids)

        # Send weighted values to target nodes
        message_fns = [
            fn.src_mul_edge('v', 'score', 'v'),
            fn.copy_edge('score', 'score'),
        ]
        reduce_fns = [
            fn.sum('v', 'wv'),
            fn.sum('score', 'z'),
        ]
        g.send_and_recv(eids, message_fns, reduce_fns)
github dmlc / dgl / examples / pytorch / jtnn / jtnn / jtnn_dec.py View on Github external
import dgl.function as DGLF
import numpy as np

# NOTE(review): presumably the maximum number of neighbours handled per node --
# confirm at the use sites.
MAX_NB = 8
# NOTE(review): presumably an upper bound on the number of decoding steps --
# confirm at the use sites.
MAX_DECODE_LEN = 100


def dfs_order(forest, roots):
    """Yield ``(edge_id ^ label, label)`` pairs from a labelled DFS of ``forest``.

    The traversal is ``dfs_labeled_edges_generator(forest, roots,
    has_reverse_edge=True)``; its outputs are walked in lockstep.
    """
    labeled = dfs_labeled_edges_generator(forest, roots, has_reverse_edge=True)
    for edge_ids, labels in zip(*labeled):
        # Shortcut that only holds for molecule trees: a reverse edge's ID is
        # the forward edge's ID XOR 1.  The general way to locate reverse
        # edges would be find_edges().
        yield edge_ids ^ labels, labels

# Built-in message function: forwards the edge feature 'm' as the message 'm'.
dec_tree_node_msg = DGLF.copy_edge(edge='m', out='m')
# Built-in reducer: sums the 'm' messages and stores the result under 'h'.
dec_tree_node_reduce = DGLF.sum(msg='m', out='h')


def dec_tree_node_update(nodes):
    """Return ``{'new': zeros}`` shaped like the existing ``'new'`` field.

    The zeros are written into a clone, so ``nodes.data['new']`` itself is
    left unmodified.
    """
    cleared = nodes.data['new'].clone()
    cleared.zero_()
    return {'new': cleared}


# Pair of built-in message functions: copy the source fields 'm' and 'rm'
# into messages of the same names.
dec_tree_edge_msg = [DGLF.copy_src(src='m', out='m'), DGLF.copy_src(src='rm', out='rm')]
# Matching pair of built-in reducers: sum 'm' into 's' and 'rm' into 'accum_rm'.
dec_tree_edge_reduce = [DGLF.sum(msg='m', out='s'), DGLF.sum(msg='rm', out='accum_rm')]


def have_slots(fa_slots, ch_slots):
    if len(fa_slots) > 2 and len(ch_slots) > 2:
        return True
    matches = []
    for i,s1 in enumerate(fa_slots):
github dmlc / dgl / examples / pytorch / jtnn / jtnn / jtnn_enc.py View on Github external
from dgl import batch, bfs_edges_generator
import dgl.function as DGLF
import numpy as np

# NOTE(review): presumably the maximum number of neighbours handled per node --
# confirm at the use sites.
MAX_NB = 8

def level_order(forest, roots):
    """Yield edge frontiers for two passes over ``forest`` starting at ``roots``.

    The frontiers of the ``reverse=True`` BFS come first, in reversed order,
    followed by the frontiers of the forward BFS in their original order.

    The original also ran ``forest.find_edges(edges[-1])`` and discarded the
    result.  That lookup was dead code, and it raised ``IndexError`` whenever
    the forward traversal produced no frontiers; without it an empty
    traversal simply yields nothing.
    """
    edges = bfs_edges_generator(forest, roots)
    edges_back = bfs_edges_generator(forest, roots, reverse=True)
    yield from reversed(edges_back)
    yield from edges

# Pair of built-in message functions: copy the source fields 'm' and 'rm'
# into messages of the same names.
enc_tree_msg = [DGLF.copy_src(src='m', out='m'), DGLF.copy_src(src='rm', out='rm')]
# Matching pair of built-in reducers: sum 'm' into 's' and 'rm' into 'accum_rm'.
enc_tree_reduce = [DGLF.sum(msg='m', out='s'), DGLF.sum(msg='rm', out='accum_rm')]
# Gather step: forward the edge feature 'm' as the message 'm' ...
enc_tree_gather_msg = DGLF.copy_edge(edge='m', out='m')
# ... and sum those messages into the field 'm'.
enc_tree_gather_reduce = DGLF.sum(msg='m', out='m')

class EncoderGatherUpdate(nn.Module):
    """Compute ``'h'`` as ``relu(W([x, m]))``.

    ``forward`` concatenates the fields ``'x'`` and ``'m'`` of ``nodes.data``
    along dimension 1 and returns a dict with the single key ``'h'``.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size

        # The input is twice hidden_size wide: one part per concatenated field.
        self.W = nn.Linear(2 * hidden_size, hidden_size)

    def forward(self, nodes):
        fields = nodes.data
        stacked = torch.cat([fields['x'], fields['m']], 1)
        return {'h': torch.relu(self.W(stacked))}