How to use the torchmeta.utils.data.dataset.CombinationMetaDataset class in torchmeta

To help you get started, we’ve selected a few torchmeta examples based on popular ways this class is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

GitHub: tristandeleu/pytorch-meta — torchmeta/utils/data/dataloader.py (view on GitHub)
def __init__(self, dataset, batch_size=1, shuffle=True, sampler=None,
                 batch_sampler=None, num_workers=0, collate_fn=None,
                 pin_memory=False, drop_last=False, timeout=0,
                 worker_init_fn=None):
        """Create a data loader over a meta-dataset.

        All arguments are forwarded to the parent loader's ``__init__``
        (assumed to be ``torch.utils.data.DataLoader`` — the keyword set
        matches). Two defaults differ from that base: ``shuffle`` defaults
        to ``True``, and a ``collate_fn`` of ``None`` is replaced by
        ``no_collate``.

        When ``dataset`` is a ``CombinationMetaDataset`` and the caller gave
        neither ``sampler`` nor ``batch_sampler``, a sampler is built here:
        ``CombinationRandomSampler`` if ``shuffle`` is true, otherwise
        ``CombinationSequentialSampler``. ``shuffle`` is then forced to
        ``False`` because ordering is now owned by that sampler.

        A user-supplied ``batch_sampler`` is passed through untouched. As
        with the base loader, such callers must also pass ``shuffle=False``.
        """
        if collate_fn is None:
            collate_fn = no_collate

        # Only install a default sampler when the caller chose no sampling
        # strategy at all. Adding one next to a user batch_sampler would make
        # the parent loader reject the call (sampler and batch_sampler are
        # mutually exclusive there).
        if (isinstance(dataset, CombinationMetaDataset)
                and sampler is None and batch_sampler is None):
            if shuffle:
                sampler = CombinationRandomSampler(dataset)
            else:
                sampler = CombinationSequentialSampler(dataset)
            # The parent loader does not accept shuffle=True together with an
            # explicit sampler; the randomness already lives in the sampler.
            shuffle = False

        super(MetaDataLoader, self).__init__(dataset, batch_size=batch_size,
            shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler,
            num_workers=num_workers, collate_fn=collate_fn,
            pin_memory=pin_memory, drop_last=drop_last, timeout=timeout,
            worker_init_fn=worker_init_fn)
GitHub: tristandeleu/pytorch-meta — torchmeta/utils/data/sampler.py (view on GitHub)
def __init__(self, data_source):
        """Initialise the sampler after checking the type of its data source.

        Args:
            data_source: the dataset to sample from; must be an instance of
                ``CombinationMetaDataset``. It is then passed unchanged to
                the parent sampler's ``__init__``.

        Raises:
            ValueError: if ``data_source`` is not a ``CombinationMetaDataset``.
        """
        if not isinstance(data_source, CombinationMetaDataset):
            # ValueError (not TypeError) is kept so existing handlers still
            # catch it; the message makes the failure self-explanatory.
            raise ValueError('CombinationSequentialSampler only supports '
                '`CombinationMetaDataset`, got `{0}`.'.format(type(data_source)))
        super(CombinationSequentialSampler, self).__init__(data_source)
GitHub: tristandeleu/pytorch-meta — torchmeta/utils/data/dataset.py (view on GitHub)
def __init__(self, dataset, num_classes_per_task, target_transform=None,
                 dataset_transform=None):
        """Wrap ``dataset`` and record how many classes make up one task.

        Args:
            dataset: the underlying dataset. It must expose ``meta_train``,
                ``meta_val``, ``meta_test`` and ``meta_split`` attributes,
                which are forwarded to the parent ``__init__``.
            num_classes_per_task: number of classes per task. Any integral
                value (``int`` or an integer-like scalar such as a NumPy
                integer) is accepted and stored as a plain ``int``.
            target_transform: forwarded unchanged to the parent ``__init__``.
            dataset_transform: forwarded unchanged to the parent ``__init__``.

        Raises:
            TypeError: if ``num_classes_per_task`` is not an integral number.
            ValueError: if ``num_classes_per_task`` is smaller than 1.
        """
        import numbers

        # numbers.Integral covers int as well as NumPy integer scalars, which
        # previously failed the strict int check despite being valid counts.
        if not isinstance(num_classes_per_task, numbers.Integral):
            raise TypeError('Unknown type for `num_classes_per_task`. Expected '
                '`int`, got `{0}`.'.format(type(num_classes_per_task)))
        num_classes_per_task = int(num_classes_per_task)
        # A task with zero or a negative number of classes is meaningless;
        # fail at construction instead of producing degenerate tasks later.
        if num_classes_per_task < 1:
            raise ValueError('`num_classes_per_task` must be a positive '
                'integer, got `{0}`.'.format(num_classes_per_task))
        self.dataset = dataset
        self.num_classes_per_task = num_classes_per_task
        super(CombinationMetaDataset, self).__init__(meta_train=dataset.meta_train,
            meta_val=dataset.meta_val, meta_test=dataset.meta_test,
            meta_split=dataset.meta_split, target_transform=target_transform,
            dataset_transform=dataset_transform)