How to use the `torchmeta.utils.data.BatchMetaDataLoader` class in torchmeta

To help you get started, we’ve selected a few torchmeta examples based on how it is commonly used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

GitHub: tristandeleu/pytorch-meta — examples/maml/train.py (view on GitHub)
def train(args):
    """Set up Omniglot meta-training and run the MAML example's training loop.

    NOTE(review): this is an excerpt. The loop body shown here stops after the
    train-split inputs/targets are moved to the device, so the rest of each
    training step is not visible.

    Args:
        args: options object (presumably an argparse namespace -- TODO confirm).
            This code reads ``folder``, ``num_shots``, ``num_ways``,
            ``download``, ``batch_size``, ``num_workers``, ``hidden_size``,
            ``device`` and ``num_batches``.
    """
    # Meta-training split of Omniglot. Per torchmeta's helper conventions,
    # each task has `num_ways` classes, `num_shots` train examples per class,
    # and 15 test examples per class (test_shots=15).
    dataset = omniglot(args.folder, shots=args.num_shots, ways=args.num_ways,
        shuffle=True, test_shots=15, meta_train=True, download=args.download)
    # Batches `batch_size` tasks at a time; each batch is indexed by split
    # name (see batch['train'] below).
    dataloader = BatchMetaDataLoader(dataset, batch_size=args.batch_size,
        shuffle=True, num_workers=args.num_workers)

    # First argument is presumably the input channel count (Omniglot images
    # are grayscale); the output size equals the number of ways per task.
    model = ConvolutionalNeuralNetwork(1, args.num_ways,
        hidden_size=args.hidden_size)
    model.to(device=args.device)
    model.train()
    # Adam over all model parameters; by its name, presumably the outer-loop
    # (meta) optimizer.
    meta_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Training loop
    # NOTE: tqdm's `total` only sizes the progress bar and does not stop
    # iteration. Presumably the truncated part of the loop breaks after
    # `num_batches` batches -- TODO confirm.
    with tqdm(dataloader, total=args.num_batches) as pbar:
        for batch_idx, batch in enumerate(pbar):
            # Reset gradients accumulated on the model's parameters.
            model.zero_grad()

            # Train split of every task in the batch, moved to the target device.
            train_inputs, train_targets = batch['train']
            train_inputs = train_inputs.to(device=args.device)
            train_targets = train_targets.to(device=args.device)
GitHub: tristandeleu/pytorch-meta — examples/protonet/train.py (view on GitHub)
def train(args):
    """Set up Omniglot meta-training and run the prototypical-network loop.

    NOTE(review): this is an excerpt. The loop body shown here stops after the
    train-split inputs/targets are moved to the device, so the rest of each
    training step is not visible.

    Args:
        args: options object (presumably an argparse namespace -- TODO confirm).
            This code reads ``folder``, ``num_shots``, ``num_ways``,
            ``download``, ``batch_size``, ``num_workers``, ``embedding_size``,
            ``hidden_size``, ``device`` and ``num_batches``.
    """
    # Meta-training split of Omniglot. Per torchmeta's helper conventions,
    # each task has `num_ways` classes, `num_shots` train examples per class,
    # and 15 test examples per class (test_shots=15).
    dataset = omniglot(args.folder, shots=args.num_shots, ways=args.num_ways,
        shuffle=True, test_shots=15, meta_train=True, download=args.download)
    # Batches `batch_size` tasks at a time; each batch is indexed by split
    # name (see batch['train'] below).
    dataloader = BatchMetaDataLoader(dataset, batch_size=args.batch_size,
        shuffle=True, num_workers=args.num_workers)

    # First argument is presumably the input channel count (Omniglot images
    # are grayscale); the second sets the embedding dimensionality.
    model = PrototypicalNetwork(1, args.embedding_size,
        hidden_size=args.hidden_size)
    model.to(device=args.device)
    model.train()
    # Adam over all model parameters.
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Training loop
    # NOTE: tqdm's `total` only sizes the progress bar and does not stop
    # iteration. Presumably the truncated part of the loop breaks after
    # `num_batches` batches -- TODO confirm.
    with tqdm(dataloader, total=args.num_batches) as pbar:
        for batch_idx, batch in enumerate(pbar):
            # Reset gradients accumulated on the model's parameters.
            model.zero_grad()

            # Train split of every task in the batch, moved to the target device.
            train_inputs, train_targets = batch['train']
            train_inputs = train_inputs.to(device=args.device)
            train_targets = train_targets.to(device=args.device)