How to use the cogdl.utils.add_self_loops function in cogdl

To help you get started, we’ve selected a few cogdl examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github THUDM / cogdl / cogdl / modules / conv / aggregator.py View on Github external
def norm(edge_index, num_nodes, edge_weight, gcn=False, dtype=None):
        """Refresh self-loops and scale each edge weight by the inverse degree.

        Args:
            edge_index: Edge tensor that unpacks into ``row, col``; its second
                dimension indexes edges.
            num_nodes: Number of nodes. ``num_nodes`` self-loop weights are
                appended, so ``add_self_loops`` is expected to append exactly
                one loop per node.
            edge_weight: Per-edge weights, or ``None`` to use 1 for every edge.
            gcn: If true the appended self-loops get weight 1, otherwise 0.
            dtype: dtype of the default weights created when ``edge_weight``
                is ``None``.

        Returns:
            A tuple ``(edge_index, weight)``: the edge index after the
            self-loop refresh, and ``edge_weight / deg[row]`` for every edge,
            where ``deg`` is the weighted degree summed over ``row``. Edges
            whose ``row`` node has zero degree get weight 0.
        """
        if edge_weight is None:
            edge_weight = torch.ones(
                (edge_index.size(1), ), dtype=dtype, device=edge_index.device)
        edge_weight = edge_weight.view(-1)
        assert edge_weight.size(0) == edge_index.size(1)

        # Drop the weights of the self-loops that are about to be removed;
        # otherwise the weights no longer line up with the columns of
        # edge_index. As before, this relies on remove_self_loops keeping the
        # remaining edges in their original order.
        keep = edge_index[0] != edge_index[1]
        edge_weight = edge_weight[keep]

        edge_index, _ = remove_self_loops(edge_index)
        edge_index = add_self_loops(edge_index, num_nodes)
        loop_weight = torch.full(
            (num_nodes, ),
            1 if gcn else 0,
            dtype=edge_weight.dtype,
            device=edge_weight.device)
        edge_weight = torch.cat([edge_weight, loop_weight], dim=0)

        row, col = edge_index
        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
        deg_inv = deg.pow(-1)
        # A zero degree (possible when gcn is False, since the loops then
        # weigh 0) gives inf here, and inf * 0 would turn into NaN below.
        deg_inv = deg_inv.masked_fill(deg_inv == float('inf'), 0)

        return edge_index, deg_inv[row] * edge_weight
github THUDM / cogdl / cogdl / modules / conv / gat_conv.py View on Github external
def forward(self, x, edge_index):
        """Project the node features per head and propagate them.

        Existing self-loops are stripped from ``edge_index`` and new ones are
        added before ``x`` is multiplied by ``self.weight`` and reshaped to
        ``(-1, self.heads, self.out_channels)`` for ``self.propagate``.
        """
        loop_free_index, _ = remove_self_loops(edge_index)
        full_index = add_self_loops(loop_free_index, num_nodes=x.size(0))

        projected = torch.mm(x, self.weight)
        per_head = projected.view(-1, self.heads, self.out_channels)
        return self.propagate(
            full_index, x=per_head, num_nodes=per_head.size(0))
github THUDM / cogdl / cogdl / modules / conv / gcn_conv.py View on Github external
def norm(edge_index, num_nodes, edge_weight, improved=False, dtype=None):
        """Refresh self-loops and apply symmetric degree normalisation.

        Args:
            edge_index: Edge tensor that unpacks into ``row, col``; its second
                dimension indexes edges.
            num_nodes: Number of nodes. ``num_nodes`` self-loop weights are
                appended, so ``add_self_loops`` is expected to append exactly
                one loop per node.
            edge_weight: Per-edge weights, or ``None`` to use 1 for every edge.
            improved: If true the appended self-loops get weight 2 instead
                of 1.
            dtype: dtype of the default weights created when ``edge_weight``
                is ``None``.

        Returns:
            A tuple ``(edge_index, weight)``: the edge index after the
            self-loop refresh, and
            ``deg[row]^-0.5 * edge_weight * deg[col]^-0.5`` for every edge,
            where ``deg`` is the weighted degree summed over ``row``.
        """
        if edge_weight is None:
            edge_weight = torch.ones(
                (edge_index.size(1), ), dtype=dtype, device=edge_index.device)
        edge_weight = edge_weight.view(-1)
        assert edge_weight.size(0) == edge_index.size(1)

        # Drop the weights of the self-loops that are about to be removed;
        # otherwise the weights no longer line up with the columns of
        # edge_index. As before, this relies on remove_self_loops keeping the
        # remaining edges in their original order.
        keep = edge_index[0] != edge_index[1]
        edge_weight = edge_weight[keep]

        edge_index, _ = remove_self_loops(edge_index)
        edge_index = add_self_loops(edge_index, num_nodes)
        loop_weight = torch.full(
            (num_nodes, ),
            1 if not improved else 2,
            dtype=edge_weight.dtype,
            device=edge_weight.device)
        edge_weight = torch.cat([edge_weight, loop_weight], dim=0)

        row, col = edge_index
        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow(-0.5)
        # Zero-degree nodes give inf; zero them so they contribute nothing.
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

        return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
github THUDM / cogdl / cogdl / transforms / add_self_loops.py View on Github external
def __call__(self, data):
        """Add self-loops to ``data.edge_index`` and coalesce the result.

        ``data`` must have ``edge_attr`` set to ``None``. The object is
        updated in place and also returned.
        """
        num_nodes, original_index = data.num_nodes, data.edge_index
        assert data.edge_attr is None

        with_loops = add_self_loops(original_index, num_nodes=num_nodes)
        data.edge_index, _ = coalesce(with_loops, None, num_nodes, num_nodes)
        return data