How to use the learn2learn.data.TaskDataset class in learn2learn

To help you get started, we’ve selected a few learn2learn examples that show popular ways TaskDataset is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github learnables / learn2learn / tests / integration / maml_miniimagenet_test_notravis.py View on Github external
l2l.data.transforms.LoadData(train_dataset),
        l2l.data.transforms.RemapLabels(train_dataset),
        l2l.data.transforms.ConsecutiveLabels(train_dataset),
    ]
    train_tasks = l2l.data.TaskDataset(train_dataset,
                                       task_transforms=train_transforms,
                                       num_tasks=20000)

    valid_transforms = [
        l2l.data.transforms.NWays(valid_dataset, ways),
        l2l.data.transforms.KShots(valid_dataset, 2*shots),
        l2l.data.transforms.LoadData(valid_dataset),
        l2l.data.transforms.ConsecutiveLabels(train_dataset),
        l2l.data.transforms.RemapLabels(valid_dataset),
    ]
    valid_tasks = l2l.data.TaskDataset(valid_dataset,
                                       task_transforms=valid_transforms,
                                       num_tasks=600)

    test_transforms = [
        l2l.data.transforms.NWays(test_dataset, ways),
        l2l.data.transforms.KShots(test_dataset, 2*shots),
        l2l.data.transforms.LoadData(test_dataset),
        l2l.data.transforms.RemapLabels(test_dataset),
        l2l.data.transforms.ConsecutiveLabels(train_dataset),
    ]
    test_tasks = l2l.data.TaskDataset(test_dataset,
                                      task_transforms=test_transforms,
                                      num_tasks=600)

    # Create model
    model = l2l.vision.models.MiniImagenetCNN(ways)
github learnables / learn2learn / tests / integration / maml_miniimagenet_test_notravis.py View on Github external
l2l.data.transforms.LoadData(valid_dataset),
        l2l.data.transforms.ConsecutiveLabels(train_dataset),
        l2l.data.transforms.RemapLabels(valid_dataset),
    ]
    valid_tasks = l2l.data.TaskDataset(valid_dataset,
                                       task_transforms=valid_transforms,
                                       num_tasks=600)

    test_transforms = [
        l2l.data.transforms.NWays(test_dataset, ways),
        l2l.data.transforms.KShots(test_dataset, 2*shots),
        l2l.data.transforms.LoadData(test_dataset),
        l2l.data.transforms.RemapLabels(test_dataset),
        l2l.data.transforms.ConsecutiveLabels(train_dataset),
    ]
    test_tasks = l2l.data.TaskDataset(test_dataset,
                                      task_transforms=test_transforms,
                                      num_tasks=600)

    # Create model
    model = l2l.vision.models.MiniImagenetCNN(ways)
    model.to(device)
    maml = l2l.algorithms.MAML(model, lr=fast_lr, first_order=False)
    opt = optim.Adam(maml.parameters(), meta_lr)
    loss = nn.CrossEntropyLoss(size_average=True, reduction='mean')

    for iteration in range(num_iterations):
        opt.zero_grad()
        meta_train_error = 0.0
        meta_train_accuracy = 0.0
        meta_valid_error = 0.0
        meta_valid_accuracy = 0.0
github learnables / learn2learn / examples / vision / maml_omniglot.py View on Github external
l2l.vision.transforms.RandomClassRotation(dataset, [0.0, 90.0, 180.0, 270.0])
    ]
    train_tasks = l2l.data.TaskDataset(dataset,
                                       task_transforms=train_transforms,
                                       num_tasks=20000)

    valid_transforms = [
        l2l.data.transforms.FilterLabels(dataset, classes[1100:1200]),
        l2l.data.transforms.NWays(dataset, ways),
        l2l.data.transforms.KShots(dataset, 2*shots),
        l2l.data.transforms.LoadData(dataset),
        l2l.data.transforms.RemapLabels(dataset),
        l2l.data.transforms.ConsecutiveLabels(dataset),
        l2l.vision.transforms.RandomClassRotation(dataset, [0.0, 90.0, 180.0, 270.0])
    ]
    valid_tasks = l2l.data.TaskDataset(dataset,
                                       task_transforms=valid_transforms,
                                       num_tasks=1024)

    test_transforms = [
        l2l.data.transforms.FilterLabels(dataset, classes[1200:]),
        l2l.data.transforms.NWays(dataset, ways),
        l2l.data.transforms.KShots(dataset, 2*shots),
        l2l.data.transforms.LoadData(dataset),
        l2l.data.transforms.RemapLabels(dataset),
        l2l.data.transforms.ConsecutiveLabels(dataset),
        l2l.vision.transforms.RandomClassRotation(dataset, [0.0, 90.0, 180.0, 270.0])
    ]
    test_tasks = l2l.data.TaskDataset(dataset,
                                      task_transforms=test_transforms,
                                      num_tasks=1024)
github learnables / learn2learn / examples / vision / maml_omniglot.py View on Github external
]),
                                                download=True)
    dataset = l2l.data.MetaDataset(omniglot)
    classes = list(range(1623))
    random.shuffle(classes)

    train_transforms = [
        l2l.data.transforms.FilterLabels(dataset, classes[:1100]),
        l2l.data.transforms.NWays(dataset, ways),
        l2l.data.transforms.KShots(dataset, 2*shots),
        l2l.data.transforms.LoadData(dataset),
        l2l.data.transforms.RemapLabels(dataset),
        l2l.data.transforms.ConsecutiveLabels(dataset),
        l2l.vision.transforms.RandomClassRotation(dataset, [0.0, 90.0, 180.0, 270.0])
    ]
    train_tasks = l2l.data.TaskDataset(dataset,
                                       task_transforms=train_transforms,
                                       num_tasks=20000)

    valid_transforms = [
        l2l.data.transforms.FilterLabels(dataset, classes[1100:1200]),
        l2l.data.transforms.NWays(dataset, ways),
        l2l.data.transforms.KShots(dataset, 2*shots),
        l2l.data.transforms.LoadData(dataset),
        l2l.data.transforms.RemapLabels(dataset),
        l2l.data.transforms.ConsecutiveLabels(dataset),
        l2l.vision.transforms.RandomClassRotation(dataset, [0.0, 90.0, 180.0, 270.0])
    ]
    valid_tasks = l2l.data.TaskDataset(dataset,
                                       task_transforms=valid_transforms,
                                       num_tasks=1024)
github learnables / learn2learn / examples / vision / meta_mnist.py View on Github external
def main(lr=0.005, maml_lr=0.01, iterations=1000, ways=5, shots=1, tps=32, fas=5, device=torch.device("cpu"),
         download_location='./data'):
    transformations = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
        lambda x: x.view(1, 28, 28),
    ])

    mnist_train = l2l.data.MetaDataset(MNIST(download_location,
                                             train=True,
                                             download=True,
                                             transform=transformations))

    train_tasks = l2l.data.TaskDataset(mnist_train,
                                       task_transforms=[
                                            l2l.data.transforms.NWays(mnist_train, ways),
                                            l2l.data.transforms.KShots(mnist_train, 2*shots),
                                            l2l.data.transforms.LoadData(mnist_train),
                                            l2l.data.transforms.RemapLabels(mnist_train),
                                            l2l.data.transforms.ConsecutiveLabels(mnist_train),
                                       ],
                                       num_tasks=1000)

    model = Net(ways)
    model.to(device)
    meta_model = l2l.algorithms.MAML(model, lr=maml_lr)
    opt = optim.Adam(meta_model.parameters(), lr=lr)
    loss_func = nn.NLLLoss(reduction='mean')

    for iteration in range(iterations):
github learnables / learn2learn / examples / vision / maml_miniimagenet.py View on Github external
LoadData(valid_dataset),
        ConsecutiveLabels(train_dataset),
        RemapLabels(valid_dataset),
    ]
    valid_tasks = l2l.data.TaskDataset(valid_dataset,
                                       task_transforms=valid_transforms,
                                       num_tasks=600)

    test_transforms = [
        NWays(test_dataset, ways),
        KShots(test_dataset, 2*shots),
        LoadData(test_dataset),
        RemapLabels(test_dataset),
        ConsecutiveLabels(train_dataset),
    ]
    test_tasks = l2l.data.TaskDataset(test_dataset,
                                      task_transforms=test_transforms,
                                      num_tasks=600)

    # Create model
    model = l2l.vision.models.MiniImagenetCNN(ways)
    model.to(device)
    maml = l2l.algorithms.MAML(model, lr=fast_lr, first_order=False)
    opt = optim.Adam(maml.parameters(), meta_lr)
    loss = nn.CrossEntropyLoss(reduction='mean')

    for iteration in range(num_iterations):
        opt.zero_grad()
        meta_train_error = 0.0
        meta_train_accuracy = 0.0
        meta_valid_error = 0.0
        meta_valid_accuracy = 0.0