Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# Meta-training task pipeline, applied in order to every sampled task:
#   NWays             -> keep `ways` classes per task
#   KShots            -> draw 2*shots samples for each kept class
#   LoadData          -> materialise the sampled (input, label) pairs
#   RemapLabels       -> relabel the task's classes
#   ConsecutiveLabels -> group the samples by label
train_transforms = [
    l2l.data.transforms.NWays(train_dataset, ways),
    l2l.data.transforms.KShots(train_dataset, 2 * shots),
    l2l.data.transforms.LoadData(train_dataset),
    l2l.data.transforms.RemapLabels(train_dataset),
    l2l.data.transforms.ConsecutiveLabels(train_dataset),
]
# Pre-sample a fixed pool of 20000 training tasks from the pipeline above.
train_tasks = l2l.data.TaskDataset(
    train_dataset,
    task_transforms=train_transforms,
    num_tasks=20000,
)
# Validation task pipeline. Every transform must be bound to the dataset the
# pipeline samples from (valid_dataset); ConsecutiveLabels was previously bound
# to train_dataset, which disagreed with the rest of this list.
valid_transforms = [
    l2l.data.transforms.NWays(valid_dataset, ways),
    l2l.data.transforms.KShots(valid_dataset, 2*shots),
    l2l.data.transforms.LoadData(valid_dataset),
    l2l.data.transforms.ConsecutiveLabels(valid_dataset),
    l2l.data.transforms.RemapLabels(valid_dataset),
]
# Fixed pool of 600 validation tasks.
valid_tasks = l2l.data.TaskDataset(valid_dataset,
                                   task_transforms=valid_transforms,
                                   num_tasks=600)
# Test task pipeline. All transforms are bound to test_dataset (ConsecutiveLabels
# previously pointed at train_dataset, unlike its siblings).
test_transforms = [
    l2l.data.transforms.NWays(test_dataset, ways),
    l2l.data.transforms.KShots(test_dataset, 2*shots),
    l2l.data.transforms.LoadData(test_dataset),
    l2l.data.transforms.RemapLabels(test_dataset),
    l2l.data.transforms.ConsecutiveLabels(test_dataset),
]
# The call was left unterminated in the snippet; close it with the same
# 600-task pool used by the other evaluation TaskDatasets in this file.
test_tasks = l2l.data.TaskDataset(test_dataset,
                                  task_transforms=test_transforms,
                                  num_tasks=600)
# Validation task pipeline (repeated in this snippet; this assignment rebinds
# valid_transforms / valid_tasks). Every transform, including ConsecutiveLabels,
# is bound to valid_dataset rather than train_dataset.
valid_transforms = [
    l2l.data.transforms.NWays(valid_dataset, ways),
    l2l.data.transforms.KShots(valid_dataset, 2*shots),
    l2l.data.transforms.LoadData(valid_dataset),
    l2l.data.transforms.ConsecutiveLabels(valid_dataset),
    l2l.data.transforms.RemapLabels(valid_dataset),
]
valid_tasks = l2l.data.TaskDataset(valid_dataset,
                                   task_transforms=valid_transforms,
                                   num_tasks=600)
# Test task pipeline (repeated in this snippet; this assignment rebinds
# test_transforms / test_tasks). ConsecutiveLabels is bound to test_dataset
# like every other transform in the list.
test_transforms = [
    l2l.data.transforms.NWays(test_dataset, ways),
    l2l.data.transforms.KShots(test_dataset, 2*shots),
    l2l.data.transforms.LoadData(test_dataset),
    l2l.data.transforms.RemapLabels(test_dataset),
    l2l.data.transforms.ConsecutiveLabels(test_dataset),
]
test_tasks = l2l.data.TaskDataset(test_dataset,
                                  task_transforms=test_transforms,
                                  num_tasks=600)
# Create model
model = l2l.vision.models.MiniImagenetCNN(ways)
model.to(device)
# MAML wrapper: inner-loop learning rate fast_lr; first_order=False keeps the
# full (second-order) meta-gradient.
maml = l2l.algorithms.MAML(model, lr=fast_lr, first_order=False)
# Outer-loop optimizer over the meta-learner's parameters.
opt = optim.Adam(maml.parameters(), meta_lr)
# `size_average` is deprecated in PyTorch: when passed, it overrides
# `reduction` and triggers a deprecation warning. `reduction='mean'` alone is
# the equivalent, supported spelling.
loss = nn.CrossEntropyLoss(reduction='mean')
for iteration in range(num_iterations):
    # Clear accumulated meta-gradients at the start of every meta-iteration.
    opt.zero_grad()
    # NOTE(review): the remainder of this loop body is not part of this snippet.
# Switch to the GPU only when CUDA was requested and at least one device is
# visible; in that case the CUDA RNG is seeded too.
# NOTE(review): when this condition is false, `device` is not assigned here, so
# it must already be defined earlier (not visible in this snippet) — confirm.
if cuda and th.cuda.device_count():
    device = th.device('cuda')
    th.cuda.manual_seed(seed)
# Create Datasets: load each mini-ImageNet split from ./data and wrap it in a
# MetaDataset; the task transforms below are built from these wrapped objects.
train_dataset = l2l.data.MetaDataset(
    l2l.vision.datasets.MiniImagenet(root='./data', mode='train'))
valid_dataset = l2l.data.MetaDataset(
    l2l.vision.datasets.MiniImagenet(root='./data', mode='validation'))
test_dataset = l2l.data.MetaDataset(
    l2l.vision.datasets.MiniImagenet(root='./data', mode='test'))
# Meta-training task pipeline (applied in list order to each sampled task):
# choose `ways` classes, take 2*shots samples per class, load them, relabel,
# then group the samples by label.
train_transforms = [
    NWays(train_dataset, ways),
    KShots(train_dataset, 2 * shots),
    LoadData(train_dataset),
    RemapLabels(train_dataset),
    ConsecutiveLabels(train_dataset),
]
# Fixed pool of 20000 training tasks.
train_tasks = l2l.data.TaskDataset(
    train_dataset,
    task_transforms=train_transforms,
    num_tasks=20000,
)
# Validation task pipeline; every transform is bound to valid_dataset
# (ConsecutiveLabels previously pointed at train_dataset).
valid_transforms = [
    NWays(valid_dataset, ways),
    KShots(valid_dataset, 2*shots),
    LoadData(valid_dataset),
    ConsecutiveLabels(valid_dataset),
    RemapLabels(valid_dataset),
]
# The call was left unterminated in the snippet; close it with the 600-task
# pool used by the other validation TaskDatasets in this file.
valid_tasks = l2l.data.TaskDataset(valid_dataset,
                                   task_transforms=valid_transforms,
                                   num_tasks=600)
# Validation task pipeline (repeated in this snippet; this assignment rebinds
# valid_transforms / valid_tasks). All transforms use valid_dataset.
valid_transforms = [
    NWays(valid_dataset, ways),
    KShots(valid_dataset, 2*shots),
    LoadData(valid_dataset),
    ConsecutiveLabels(valid_dataset),
    RemapLabels(valid_dataset),
]
valid_tasks = l2l.data.TaskDataset(valid_dataset,
                                   task_transforms=valid_transforms,
                                   num_tasks=600)
# Test task pipeline; ConsecutiveLabels is bound to test_dataset like the
# other transforms (it previously pointed at train_dataset).
test_transforms = [
    NWays(test_dataset, ways),
    KShots(test_dataset, 2*shots),
    LoadData(test_dataset),
    RemapLabels(test_dataset),
    ConsecutiveLabels(test_dataset),
]
test_tasks = l2l.data.TaskDataset(test_dataset,
                                  task_transforms=test_transforms,
                                  num_tasks=600)
# Create model: a MiniImagenetCNN with `ways` outputs, moved onto `device`
# (nn.Module.to moves parameters in place and returns the module itself).
model = l2l.vision.models.MiniImagenetCNN(ways).to(device)
# Second-order MAML (first_order=False) with inner-loop rate fast_lr.
maml = l2l.algorithms.MAML(model, lr=fast_lr, first_order=False)
# Adam drives the outer (meta) update at learning rate meta_lr.
opt = optim.Adam(maml.parameters(), lr=meta_lr)
loss = nn.CrossEntropyLoss(reduction='mean')
for iteration in range(num_iterations):
    # Each meta-iteration begins with freshly zeroed meta-gradients.
    opt.zero_grad()
    # NOTE(review): the rest of the loop body lies outside this snippet.