# Imports assumed for this snippet (exact module paths are assumptions): the
# `Dataset` base class used by `Task` below is torchmeta's transform-aware
# dataset, since it accepts `transform` / `target_transform` in its constructor.
from torch.utils.data import ConcatDataset, Subset
from torchvision.transforms import Compose
from torchmeta.utils.data.dataset import Dataset


def get_indices(self, task):
    # Method of a task-splitter transform: dispatches on the concrete task type.
    if isinstance(task, ConcatTask):
        indices = self.get_indices_concattask(task)
    elif isinstance(task, Task):
        indices = self.get_indices_task(task)
    else:
        raise ValueError('The task must be of type `ConcatTask` or `Task`, '
                         'got type `{0}`.'.format(type(task)))
    return indices
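
# Illustrative sketch (not part of the original file): how a splitter's
# `__call__` might use `get_indices` to carve a task into named splits,
# assuming `get_indices` returns a mapping from split name to index list.
# `SubsetTask` is defined further below; the enclosing splitter class and its
# `splits` attribute are assumptions.
from collections import OrderedDict

def __call__(self, task):
    indices = self.get_indices(task)
    return OrderedDict([(split, SubsetTask(task, indices[split]))
                        for split in self.splits])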
class Task(Dataset):
    """Base class for a classification task.

    Parameters
    ----------
    num_classes : int
        Number of classes for the classification task.

    transform : callable, optional
        Transform applied to the input samples.

    target_transform : callable, optional
        Transform applied to the targets (class labels).
    """
    def __init__(self, num_classes, transform=None, target_transform=None):
        super(Task, self).__init__(transform=transform,
                                   target_transform=target_transform)
        self.num_classes = num_classes
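
# Minimal sketch of a concrete task (hypothetical, not part of the original
# file): subclasses of `Task` are expected to implement `__len__` and
# `__getitem__`, applying `self.transform` / `self.target_transform` when set
# (assuming the transform-aware base class stores them under those names).
class InMemoryTask(Task):
    def __init__(self, samples, num_classes, transform=None,
                 target_transform=None):
        super(InMemoryTask, self).__init__(num_classes, transform=transform,
                                           target_transform=target_transform)
        self.samples = samples  # list of (input, label) pairs

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        input, target = self.samples[index]
        if self.transform is not None:
            input = self.transform(input)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return input, target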
class ConcatTask(Task, ConcatDataset):
    def __init__(self, datasets, num_classes, target_transform=None):
        Task.__init__(self, num_classes)
        ConcatDataset.__init__(self, datasets)
        # Propagate the task-level target transform to every constituent dataset.
        for task in self.datasets:
            task.target_transform_append(target_transform)

    def __getitem__(self, index):
        return ConcatDataset.__getitem__(self, index)
class SubsetTask(Task, Subset):
    def __init__(self, dataset, indices, num_classes=None,
                 target_transform=None):
        # Default to the number of classes of the wrapped task.
        if num_classes is None:
            num_classes = dataset.num_classes
        Task.__init__(self, num_classes)
        Subset.__init__(self, dataset, indices)
        self.dataset.target_transform_append(target_transform)

    def __getitem__(self, index):
        return Subset.__getitem__(self, index)
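
# Illustrative usage (hypothetical names, shown as comments because the
# objects are not defined in this file): an N-way task is typically built by
# concatenating one relabelled dataset per sampled class, then split into
# support/query subsets.
# task = ConcatTask(per_class_datasets, num_classes=len(per_class_datasets),
#                   target_transform=label_remapping_transform)
# support = SubsetTask(task, support_indices)
# query = SubsetTask(task, query_indices)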
def apply_wrapper(wrapper, task_or_dataset=None):
    # With no argument, return the wrapper itself so it can be used later as a
    # dataset transform; otherwise apply it to a single task, or compose it
    # into a meta-dataset's `dataset_transform`.
    if task_or_dataset is None:
        return wrapper

    from torchmeta.utils.data import MetaDataset
    if isinstance(task_or_dataset, Task):
        return wrapper(task_or_dataset)
    elif isinstance(task_or_dataset, MetaDataset):
        if task_or_dataset.dataset_transform is None:
            dataset_transform = wrapper
        else:
            dataset_transform = Compose([
                task_or_dataset.dataset_transform, wrapper])
        task_or_dataset.dataset_transform = dataset_transform
        return task_or_dataset
    else:
        raise NotImplementedError()
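
# Illustrative usage (hypothetical names, shown as comments): `apply_wrapper`
# either wraps a single task immediately, or registers the wrapper on a
# meta-dataset's `dataset_transform` so every sampled task gets wrapped.
# split_task = apply_wrapper(splitter, some_task)        # returns the wrapped task
# meta_dataset = apply_wrapper(splitter, meta_dataset)   # composes dataset_transform
# deferred = apply_wrapper(splitter)                      # returns `splitter` unchanged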