How to use the learn2learn.data.TaskGenerator class in learn2learn

To help you get started, we've selected a few learn2learn examples based on popular ways TaskGenerator is used in public projects.


Example from learnables/learn2learn: tests/integration/meta_mnist_tests.py
def main(lr=0.005, maml_lr=0.01, iterations=1000, ways=5, shots=1, tps=32, fas=5, device=torch.device("cpu"),
         download_location="/tmp/mnist"):
    # tps: tasks sampled per meta-step; fas: fast-adaptation steps per task.
    transformations = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
        lambda x: x.view(1, 1, 28, 28),
    ])

    mnist_train = l2l.data.MetaDataset(MNIST(download_location, train=True, download=True, transform=transformations))
    # mnist_test = MNIST(download_location, train=False, download=True, transform=transformations)

    train_gen = l2l.data.TaskGenerator(mnist_train, ways=ways, tasks=10000)
    # test_gen = l2l.data.TaskGenerator(mnist_test, ways=ways)

    model = Net(ways)
    model.to(device)
    meta_model = l2l.algorithms.MAML(model, lr=maml_lr)
    opt = optim.Adam(meta_model.parameters(), lr=lr)
    loss_func = nn.NLLLoss(reduction="sum")

    for iteration in range(iterations):
        iteration_error = 0.0
        iteration_acc = 0.0
        for _ in range(tps):
            learner = meta_model.clone()
            train_task = train_gen.sample()
            valid_task = train_gen.sample(task=train_task.sampled_task)

Example from learnables/learn2learn: examples/vision/proto_net.py
def main(model: Module, optimiser: Optimizer, loss_fn: Callable, epochs: int,
         fit_function: Callable = gradient_step, fit_function_kwargs: dict = None):
    # Use None instead of a mutable {} default, which would be shared across calls.
    fit_function_kwargs = fit_function_kwargs or {}
    batch_size = None

    print('Begin training...')

    train_generator = l2l.data.TaskGenerator(dataset=omniglot, ways=args.k_train)
    eval_generator = l2l.data.TaskGenerator(dataset=omniglot, ways=args.k_test)

    monitor = f'val_{args.n_test}-shot_{args.k_test}-way_acc'
    monitor_op = np.less
    best = np.inf
    epochs_since_last_save = 0
    for epoch in range(1, epochs + 1):
        lrs = [lr_schedule(epoch, param_group['lr']) for param_group in optimiser.param_groups]

        optimiser = set_lr(epoch, optimiser, lrs)

        epoch_logs = {}
        for batch_index in range(training_episodes):
            batch_logs = dict(batch=batch_index, size=(batch_size or 1))

            support_t = train_generator.sample(shots=args.q_train)
            query_t = train_generator.sample(shots=args.q_test)

Example from learnables/learn2learn: examples/vision/protonet_omniglot.py
    omniglot = FullOmniglot(root='./data',
                            transform=transforms.Compose([
                                l2l.vision.transforms.RandomDiscreteRotation(
                                    [0.0, 90.0, 180.0, 270.0]),
                                transforms.Resize(28, interpolation=LANCZOS),
                                transforms.ToTensor(),
                                lambda x: 1.0 - x,
                            ]),
                            download=True)
    omniglot = l2l.data.MetaDataset(omniglot)
    classes = list(range(1623))
    random.shuffle(classes)
    train_generator = l2l.data.TaskGenerator(dataset=omniglot,
                                             ways=args.k_train,
                                             classes=classes[:1100],
                                             tasks=20000)
    valid_generator = l2l.data.TaskGenerator(dataset=omniglot,
                                             ways=args.k_test,
                                             classes=classes[1100:1200],
                                             tasks=1024)
    test_generator = l2l.data.TaskGenerator(dataset=omniglot,
                                            ways=args.k_test,
                                            classes=classes[1200:],
                                            tasks=1024)

    model = OmniglotCNN()
    model.to(device, dtype=torch.double)

    optimiser = Adam(model.parameters(), lr=1e-3)
    loss_fn = torch.nn.NLLLoss().to(device)  # keep the loss on the same device as the model

    # test_generator = l2l.data.TaskGenerator(dataset=omniglot, ways=args.k_test, tasks=1024)
    # support_t = test_generator.sample(shots=args.q_test)

Example from learnables/learn2learn: examples/vision/proto_net.py
    omniglot = FullOmniglot(root='./data',
                            transform=transforms.Compose([
                                transforms.Resize(28, interpolation=LANCZOS),
                                transforms.ToTensor(),
                                lambda x: 1.0 - x,
                            ]),
                            download=True)
    omniglot = l2l.data.MetaDataset(omniglot)

    model = OmniglotCNN()
    model.to(device, dtype=torch.double)

    optimiser = Adam(model.parameters(), lr=1e-3)
    loss_fn = torch.nn.NLLLoss().to(device)  # keep the loss on the same device as the model

    eval_generator = l2l.data.TaskGenerator(dataset=omniglot, ways=args.k_test)
    support_t = eval_generator.sample(shots=args.q_test)
    query_t = eval_generator.sample(shots=args.q_test)

    main(
        model,
        optimiser,
        loss_fn,
        epochs=n_epochs,
        fit_function=proto_net_episode,
        fit_function_kwargs={'n_shot': args.n_train, 'k_way': args.k_train, 'q_queries': args.q_train, 'train': True,
                             'distance': args.distance},
    )

Example from learnables/learn2learn: examples/vision/meta_mnist.py
def main(lr=0.005, maml_lr=0.01, iterations=1000, ways=5, shots=1, tps=32, fas=5, device=torch.device("cpu"),
         download_location="/tmp/mnist", test=False):
    # tps: tasks sampled per meta-step; fas: fast-adaptation steps per task.
    transformations = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
        lambda x: x.view(1, 1, 28, 28),
    ])

    mnist_train = l2l.data.MetaDataset(MNIST(download_location, train=True, download=True, transform=transformations))
    # mnist_test = MNIST(download_location, train=False, download=True, transform=transformations)

    train_gen = l2l.data.TaskGenerator(mnist_train, ways=ways, tasks=10000)
    # test_gen = l2l.data.TaskGenerator(mnist_test, ways=ways)

    model = Net(ways)
    model.to(device)
    meta_model = l2l.algorithms.MAML(model, lr=maml_lr)
    opt = optim.Adam(meta_model.parameters(), lr=lr)
    loss_func = nn.NLLLoss(reduction="sum")

    tqdm_bar = tqdm(range(iterations))
    for iteration in tqdm_bar:
        iteration_error = 0.0
        iteration_acc = 0.0
        for _ in range(tps):
            learner = meta_model.clone()
            train_task = train_gen.sample()                              # draw a fresh N-way task
            valid_task = train_gen.sample(task=train_task.sampled_task)  # reuse the same sampled classes for validation