Hanqing Zeng (zengh@usc.edu); Hongkuan Zhou (hongkuaz@usc.edu)
"""
from graphsaint.globals import *
from graphsaint.pytorch_version.models import GraphSAINT
from graphsaint.pytorch_version.minibatch import Minibatch
from graphsaint.utils import *
from graphsaint.metric import *
from graphsaint.pytorch_version.utils import *
from ogb.nodeproppred import Evaluator
import torch
import time
# Module-level OGB evaluator (ogb.nodeproppred.Evaluator) bound to the
# 'ogbn-products' dataset; created once at import time.
evaluator = Evaluator(name='ogbn-products')
def evaluate_full_batch(model, minibatch, mode='val'):
"""
Full batch evaluation: for validation and test sets only.
When calculating the F1 score, we will mask the relevant root nodes
(e.g., those belonging to the val / test sets).
"""
loss,preds,labels = model.eval_step(*minibatch.one_batch(mode=mode))
if mode == 'val':
node_target = [minibatch.node_val]
elif mode == 'test':
node_target = [minibatch.node_test]
else:
assert mode == 'valtest'
node_target = [minibatch.node_val, minibatch.node_test]
labels = labels.argmax(dim=-1, keepdim=True)