Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def __init__(self, dataset, batch_size=1, shuffle=True, sampler=None,
             batch_sampler=None, num_workers=0, collate_fn=None,
             pin_memory=False, drop_last=False, timeout=0,
             worker_init_fn=None):
    """Initialise the meta data loader and forward everything to the parent.

    Every argument is passed on to ``super(MetaDataLoader, self).__init__``.
    Before that, ``collate_fn``, ``sampler`` and ``shuffle`` may be rewritten
    as described below.

    Args:
        dataset: Dataset to load from. If it is a ``CombinationMetaDataset``
            and the caller supplied neither ``sampler`` nor ``batch_sampler``,
            a combination sampler is created automatically.
        batch_size: Forwarded unchanged.
        shuffle: For an auto-created combination sampler, selects
            ``CombinationRandomSampler`` (True) or
            ``CombinationSequentialSampler`` (False). It is then forced to
            False before forwarding. Otherwise it is forwarded unchanged.
        sampler: Explicit sampler. When given, no sampler is auto-created.
        batch_sampler: Explicit batch sampler. When given, no sampler is
            auto-created.
        num_workers, pin_memory, drop_last, timeout, worker_init_fn:
            Forwarded unchanged.
        collate_fn: Collation function. Defaults to ``no_collate`` when None.
    """
    if collate_fn is None:
        collate_fn = no_collate

    # Only choose a default sampler when the caller supplied neither a
    # sampler nor a batch_sampler. An auto-created sampler alongside a
    # user-supplied batch_sampler would make the parent loader reject the
    # combination. NOTE(review): assumes the parent is torch's DataLoader,
    # which treats `sampler` and `batch_sampler` as mutually exclusive.
    if (isinstance(dataset, CombinationMetaDataset)
            and sampler is None and batch_sampler is None):
        if shuffle:
            sampler = CombinationRandomSampler(dataset)
        else:
            sampler = CombinationSequentialSampler(dataset)
        # The chosen sampler already encodes random vs. sequential order, so
        # the parent is not asked to shuffle on top of it (presumably the
        # parent rejects shuffle=True together with a sampler).
        shuffle = False

    super(MetaDataLoader, self).__init__(dataset, batch_size=batch_size,
        shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler,
        num_workers=num_workers, collate_fn=collate_fn,
        pin_memory=pin_memory, drop_last=drop_last, timeout=timeout,
        worker_init_fn=worker_init_fn)
def __init__(self, data_source):
    """Create a sequential sampler over a ``CombinationMetaDataset``.

    Args:
        data_source: Dataset to sample from. Must be an instance of
            ``CombinationMetaDataset``.

    Raises:
        ValueError: If ``data_source`` is not a ``CombinationMetaDataset``.
            The exception type is kept for backward compatibility; the
            message now names the offending type.
    """
    if not isinstance(data_source, CombinationMetaDataset):
        raise ValueError('CombinationSequentialSampler only supports '
                         'datasets of type `CombinationMetaDataset`, got '
                         '`{0}`.'.format(type(data_source)))
    super(CombinationSequentialSampler, self).__init__(data_source)
def __init__(self, dataset, num_classes_per_task, target_transform=None,
             dataset_transform=None):
    """Wrap ``dataset`` and fix the number of classes drawn per task.

    The wrapped dataset's ``meta_train``, ``meta_val``, ``meta_test`` and
    ``meta_split`` attributes are read and forwarded to the parent
    initialiser, together with the two transform arguments.

    Args:
        dataset: Underlying dataset. It must expose ``meta_train``,
            ``meta_val``, ``meta_test`` and ``meta_split`` attributes.
        num_classes_per_task: Number of classes per task. Must be an ``int``.
        target_transform: Forwarded to the parent initialiser.
        dataset_transform: Forwarded to the parent initialiser.

    Raises:
        TypeError: If ``num_classes_per_task`` is not an ``int``.
    """
    received_type = type(num_classes_per_task)
    if not isinstance(num_classes_per_task, int):
        raise TypeError('Unknown type for `num_classes_per_task`. Expected '
                        '`int`, got `{0}`.'.format(received_type))

    self.dataset = dataset
    self.num_classes_per_task = num_classes_per_task

    # Mirror the wrapped dataset's split selection onto this wrapper. The
    # attributes are read in the same order as the original keyword list.
    split_flags = dict(
        (flag, getattr(dataset, flag))
        for flag in ('meta_train', 'meta_val', 'meta_test', 'meta_split')
    )
    super(CombinationMetaDataset, self).__init__(
        target_transform=target_transform,
        dataset_transform=dataset_transform,
        **split_flags)