for i, line in enumerate(context):
    edge_index[:, i] = list(map(int, line.strip().split("\t")))
edge_index = torch.from_numpy(edge_index).to(torch.int)

# Build a multi-hot label matrix: label[node, i] = 1 iff node belongs to community i.
with open(cmty_path) as f:
    context = f.readlines()
    print("class number: ", len(context))
    label = np.zeros((num_node, len(context)))
    for i, line in enumerate(context):
        line = map(int, line.strip().split("\t"))
        for node in line:
            label[node, i] = 1

y = torch.from_numpy(label).to(torch.float)
data = Data(x=None, edge_index=edge_index, y=y)
return data
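For reference, the reader above consumes two tab-separated files: an edge list with one "u\tv" pair per line, and a community file with one line per community listing its member nodes. A minimal standalone sketch of the multi-hot label construction, with hypothetical file contents inlined:

import numpy as np

# Hypothetical ".cmty" contents: community i on line i, members tab-separated.
cmty_lines = ["0\t1\t2", "2\t3"]
num_node = 4

label = np.zeros((num_node, len(cmty_lines)))
for i, line in enumerate(cmty_lines):
    for node in map(int, line.strip().split("\t")):
        label[node, i] = 1

print(label)  # node 2 belongs to both communities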
valid_data = {}
with open(osp.join(folder, "{}".format("valid.txt")), "r") as f:
    for line in f:
        items = line.strip().split()
        if items[0] not in valid_data:
            valid_data[items[0]] = [[], []]
        valid_data[items[0]][1 - int(items[3])].append(
            [int(items[1]), int(items[2])]
        )

test_data = {}
with open(osp.join(folder, "{}".format("test.txt")), "r") as f:
    for line in f:
        items = line.strip().split()
        if items[0] not in test_data:
            test_data[items[0]] = [[], []]
        # items[3] is a 0/1 label: positive pairs (1) go to slot 0, negatives to slot 1.
        test_data[items[0]][1 - int(items[3])].append(
            [int(items[1]), int(items[2])]
        )

data = Data()
data.train_data = train_data
data.valid_data = valid_data
data.test_data = test_data
return data
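Each split file is whitespace-separated with four fields per line: edge type, head node, tail node, and a 0/1 label. A standalone sketch of the grouping above, using hypothetical lines:

# Hypothetical "test.txt" lines: <edge_type> <head> <tail> <label>
lines = ["1 0 5 1", "1 0 7 0", "2 3 4 1"]

test_data = {}
for line in lines:
    items = line.strip().split()
    if items[0] not in test_data:
        test_data[items[0]] = [[], []]
    test_data[items[0]][1 - int(items[3])].append([int(items[1]), int(items[2])])

print(test_data)  # {'1': [[[0, 5]], [[0, 7]]], '2': [[[3, 4]], []]}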
class EdgelistLabel(Dataset):
    r"""Networks from https://github.com/THUDM/ProNE/raw/master/data.

    Args:
        root (string): Root directory where the dataset should be saved.
        name (string): The name of the dataset (:obj:`"Wikipedia"`).
    """

    url = "https://github.com/THUDM/ProNE/raw/master/data"

    def __init__(self, root, name):
        self.name = name
        super(EdgelistLabel, self).__init__(root)
        self.data = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        # Assumed naming, matching the ".ungraph"/".cmty" pair read above.
        return ["{}.ungraph".format(self.name), "{}.cmty".format(self.name)]
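Hypothetical usage of the class above (the "Wikipedia" name comes from its docstring); processing happens on first instantiation:

dataset = EdgelistLabel("./data/wikipedia", "Wikipedia")
data = dataset.data
print(data.edge_index.shape, data.y.shape)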
class GatneDataset(Dataset):
    r"""The network datasets "Amazon", "Twitter" and "YouTube" from the
    `"Representation Learning for Attributed Multiplex Heterogeneous Network"
    <https://arxiv.org/abs/1905.01669>`_ paper.

    Args:
        root (string): Root directory where the dataset should be saved.
        name (string): The name of the dataset (:obj:`"Amazon"`,
            :obj:`"Twitter"`, :obj:`"YouTube"`).
    """

    url = "https://github.com/THUDM/GATNE/raw/master/data"

    def __init__(self, root, name):
        self.name = name
        super(GatneDataset, self).__init__(root)
        self.data = torch.load(self.processed_paths[0])
def collate(self, data_list):
    r"""Collates a python list of data objects to the internal storage
    format of :class:`cogdl.data.InMemoryDataset`."""
    keys = data_list[0].keys
    data = Data()

    for key in keys:
        data[key] = []
    slices = {key: [0] for key in keys}

    # For each (graph, key) pair, append the tensor and record the running
    # offset along its concatenation dimension.
    for item, key in product(data_list, keys):
        data[key].append(item[key])
        s = slices[key][-1] + item[key].size(item.cat_dim(key, item[key]))
        slices[key].append(s)

    for key in keys:
        data[key] = torch.cat(
            data[key], dim=data_list[0].cat_dim(key, data_list[0][key])
        )
        slices[key] = torch.LongTensor(slices[key])

    return data, slices
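To illustrate the bookkeeping (a standalone sketch, not the library API): concatenating per-graph tensors while recording cumulative offsets lets each graph be sliced back out later.

import torch

# Two hypothetical per-graph node-feature tensors of 3 and 2 nodes.
xs = [torch.randn(3, 8), torch.randn(2, 8)]

slices = [0]
for x in xs:
    slices.append(slices[-1] + x.size(0))  # cumulative node counts: [0, 3, 5]

big_x = torch.cat(xs, dim=0)               # shape (5, 8)
graph1 = big_x[slices[1]:slices[2]]        # recovers the second graph's 2 rows
assert graph1.shape == (2, 8)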
import re

import torch

from cogdl.data import Data


class Batch(Data):
    r"""A plain old python object modeling a batch of graphs as one big
    (disconnected) graph. With :class:`cogdl.data.Data` being the
    base class, all its methods can also be used here.
    In addition, single graphs can be reconstructed via the assignment vector
    :obj:`batch`, which maps each node to its respective graph identifier.
    """

    def __init__(self, batch=None, **kwargs):
        super(Batch, self).__init__(**kwargs)
        self.batch = batch

    @staticmethod
    def from_data_list(data_list):
        r"""Constructs a batch object from a python list holding
        :class:`cogdl.data.Data` objects.
        The assignment vector :obj:`batch` is created on the fly."""
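The body of from_data_list is elided in this excerpt; conceptually, the assignment vector it documents repeats each graph's index once per node. A standalone sketch of that idea (not the library implementation):

import torch

num_nodes = [3, 2, 4]  # hypothetical node counts for three graphs
batch = torch.cat(
    [torch.full((n,), i, dtype=torch.long) for i, n in enumerate(num_nodes)]
)
print(batch)  # tensor([0, 0, 0, 1, 1, 2, 2, 2, 2])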
def get(self, idx):
    data = Data()
    for key in self.data.keys:
        item, slices = self.data[key], self.slices[key]
        # Build a full-slice index, then narrow only the concatenation
        # dimension to the idx-th graph's range (`repeat` is itertools.repeat).
        s = list(repeat(slice(None), item.dim()))
        s[self.data.cat_dim(key, item)] = slice(slices[idx], slices[idx + 1])
        data[key] = item[s]
    return data
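The indexing trick above builds a list of slices and narrows a single axis; a standalone demonstration:

from itertools import repeat

import torch

item = torch.arange(24).reshape(4, 6)
s = list(repeat(slice(None), item.dim()))  # [slice(None), slice(None)]
s[0] = slice(1, 3)                         # narrow only dimension 0
print(item[tuple(s)].shape)                # torch.Size([2, 6]); modern PyTorch prefers a tuple index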
@register_dataset("cora")
class CoraDataset(Planetoid):
    def __init__(self):
        dataset = "Cora"
        path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset)
        super(CoraDataset, self).__init__(path, dataset, T.TargetIndegree())


@register_dataset("citeseer")
class CiteSeerDataset(Planetoid):
    def __init__(self):
        dataset = "CiteSeer"
        path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset)
        super(CiteSeerDataset, self).__init__(path, dataset, T.TargetIndegree())


@register_dataset("pubmed")
class PubMedDataset(Planetoid):
    def __init__(self):
        dataset = "PubMed"
        path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset)
        super(PubMedDataset, self).__init__(path, dataset, T.TargetIndegree())


@register_dataset("reddit")
class RedditDataset(Reddit):
    def __init__(self):
        dataset = "Reddit"
        path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset)
        super(RedditDataset, self).__init__(path, T.TargetIndegree())
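register_dataset is not defined in this excerpt; a minimal sketch of the registry pattern it suggests, with hypothetical names:

DATASET_REGISTRY = {}

def register_dataset(name):
    # Hypothetical minimal registry; the real cogdl decorator may differ.
    def decorator(cls):
        DATASET_REGISTRY[name] = cls
        return cls
    return decorator

def build_dataset_by_name(name):
    return DATASET_REGISTRY[name]()  # e.g. build_dataset_by_name("cora")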
    def process(self):
        data = read_gatne_data(self.raw_dir)
        torch.save(data, self.processed_paths[0])

    def __repr__(self):
        return "{}()".format(self.name)


@register_dataset("amazon")
class AmazonDataset(GatneDataset):
    def __init__(self):
        dataset = "amazon"
        path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset)
        super(AmazonDataset, self).__init__(path, dataset)


@register_dataset("twitter")
class TwitterDataset(GatneDataset):
    def __init__(self):
        dataset = "twitter"
        path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset)
        super(TwitterDataset, self).__init__(path, dataset)


@register_dataset("youtube")
class YouTubeDataset(GatneDataset):
    def __init__(self):
        dataset = "youtube"
        path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset)
        super(YouTubeDataset, self).__init__(path, dataset)
@register_dataset("blogcatalog")
class BlogcatalogDataset(MatlabMatrix):
    def __init__(self):
        dataset, filename = "blogcatalog", "blogcatalog"
        url = "http://leitang.net/code/social-dimension/data/"
        path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset)
        super(BlogcatalogDataset, self).__init__(path, filename, url)


@register_dataset("flickr")
class FlickrDataset(MatlabMatrix):
    def __init__(self):
        dataset, filename = "flickr", "flickr"
        url = "http://leitang.net/code/social-dimension/data/"
        path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset)
        super(FlickrDataset, self).__init__(path, filename, url)


@register_dataset("wikipedia")
class WikipediaDataset(MatlabMatrix):
    def __init__(self):
        dataset, filename = "wikipedia", "POS"
        url = "http://snap.stanford.edu/node2vec/"
        path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset)
        super(WikipediaDataset, self).__init__(path, filename, url)


@register_dataset("ppi")
class PPIDataset(MatlabMatrix):
    def __init__(self):
        dataset, filename = "ppi", "Homo_sapiens"
        url = "http://snap.stanford.edu/node2vec/"
        path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset)
        super(PPIDataset, self).__init__(path, filename, url)
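MatlabMatrix itself is not shown here. The .mat files behind these datasets (blogcatalog, POS, Homo_sapiens) are conventionally distributed with a sparse adjacency under a "network" key and a multi-label indicator under "group"; a hedged sketch of inspecting one with scipy:

from scipy.io import loadmat

# Assumed key names, following the DeepWalk/node2vec distribution of these files.
mat = loadmat("Homo_sapiens.mat")
adjacency = mat["network"]  # nodes x nodes sparse adjacency
labels = mat["group"]       # nodes x classes multi-label indicator
print(adjacency.shape, labels.shape)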