How to use the torchmeta.utils.data.task.Task class in torchmeta

To help you get started, we’ve selected a few torchmeta examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github tristandeleu / pytorch-meta / torchmeta / utils / data / task.py View on Github external
def __init__(self, num_classes, transform=None, target_transform=None):
        """Initialize the task and record how many classes it has.

        Parameters
        ----------
        num_classes : int
            Number of classes; stored on the instance as `num_classes`.

        transform : optional
            Forwarded by keyword to the parent initializer.

        target_transform : optional
            Forwarded by keyword to the parent initializer.
        """
        # Collect the keyword arguments meant for the parent class, then
        # forward them in a single call.
        forwarded = {
            'transform': transform,
            'target_transform': target_transform,
        }
        super(Task, self).__init__(**forwarded)
        self.num_classes = num_classes
github tristandeleu / pytorch-meta / torchmeta / transforms / splitters.py View on Github external
def get_indices(self, task):
        """Return the indices for `task`, dispatching on the type of task.

        A `ConcatTask` is handled by `get_indices_concattask`, any other
        `Task` by `get_indices_task`; the result of the selected helper is
        returned unchanged.

        Raises
        ------
        ValueError
            If `task` is neither a `ConcatTask` nor a `Task`.
        """
        # `ConcatTask` is checked before `Task`, in the same order as before.
        if isinstance(task, ConcatTask):
            return self.get_indices_concattask(task)

        if isinstance(task, Task):
            return self.get_indices_task(task)

        # Anything else is rejected, reporting the offending type.
        raise ValueError('The task must be of type `ConcatTask` or `Task`, '
            'Got type `{0}`.'.format(type(task)))
github tristandeleu / pytorch-meta / torchmeta / utils / data / task.py View on Github external
class Task(Dataset):
    """Base class for a classification task.

    Parameters
    ----------
    num_classes : int
        Number of classes for the classification task. Stored on the
        instance as `num_classes`.

    transform : optional
        Forwarded by keyword to the `Dataset` initializer.

    target_transform : optional
        Forwarded by keyword to the `Dataset` initializer.
    """
    def __init__(self, num_classes, transform=None, target_transform=None):
        # Gather what the parent initializer needs, then forward it in a
        # single call before recording the class count.
        parent_kwargs = dict(transform=transform,
                             target_transform=target_transform)
        super(Task, self).__init__(**parent_kwargs)
        self.num_classes = num_classes


class ConcatTask(Task, ConcatDataset):
    """A task made of several datasets, combining `Task` and `ConcatDataset`.

    Parameters
    ----------
    datasets
        Passed as-is to `ConcatDataset.__init__`.

    num_classes : int
        Number of classes, passed to `Task.__init__`.

    target_transform : optional
        Handed to `target_transform_append` of every entry in
        `self.datasets`. It is not passed to `Task.__init__`.
    """
    def __init__(self, datasets, num_classes, target_transform=None):
        # Both base classes are initialized explicitly, `Task` first.
        Task.__init__(self, num_classes)
        ConcatDataset.__init__(self, datasets)
        # `self.datasets` is presumably populated by `ConcatDataset.__init__`
        # (not set in this class) -- TODO confirm. Each entry must expose
        # `target_transform_append`.
        for task in self.datasets:
            task.target_transform_append(target_transform)

    def __getitem__(self, index):
        # Item access is delegated explicitly to `ConcatDataset`.
        return ConcatDataset.__getitem__(self, index)


class SubsetTask(Task, Subset):
    """A task derived from both `Task` and `Subset`.

    NOTE(review): only the beginning of `__init__` is visible in this
    excerpt; `indices` and `target_transform` are not used in the part shown.
    """
    def __init__(self, dataset, indices, num_classes=None,
                 target_transform=None):
        # When no class count is given, reuse the one stored on `dataset`.
        if num_classes is None:
            num_classes = dataset.num_classes
        Task.__init__(self, num_classes)
github tristandeleu / pytorch-meta / torchmeta / utils / data / task.py View on Github external
def __init__(self, datasets, num_classes, target_transform=None):
        """Initialize both base classes, then pass on the target transform.

        Parameters
        ----------
        datasets
            Passed as-is to `ConcatDataset.__init__`.

        num_classes
            Passed to `Task.__init__`.

        target_transform : optional
            Handed to `target_transform_append` of every entry in
            `self.datasets`.
        """
        # `Task` is initialized before `ConcatDataset`.
        Task.__init__(self, num_classes)
        ConcatDataset.__init__(self, datasets)
        # Each entry of `self.datasets` must expose `target_transform_append`.
        for task in self.datasets:
            task.target_transform_append(target_transform)
github tristandeleu / pytorch-meta / torchmeta / transforms / utils.py View on Github external
def apply_wrapper(wrapper, task_or_dataset=None):
    """Apply `wrapper` to a task, or register it on a meta-dataset.

    Parameters
    ----------
    wrapper : callable
        Called directly on a `Task`, or stored as (part of) the
        `dataset_transform` of a `MetaDataset`.

    task_or_dataset : Task or MetaDataset, optional
        Object the wrapper applies to.

    Returns
    -------
    `wrapper` itself when `task_or_dataset` is None; the result of
    `wrapper(task_or_dataset)` for a `Task`; the same `MetaDataset`
    object, with its `dataset_transform` attribute updated, otherwise.

    Raises
    ------
    NotImplementedError
        If `task_or_dataset` is neither None, a `Task` nor a `MetaDataset`.
    """
    # Nothing to apply the wrapper to: hand it back untouched.
    if task_or_dataset is None:
        return wrapper

    # Imported at call time, and only once an object was actually given.
    from torchmeta.utils.data import MetaDataset

    if isinstance(task_or_dataset, Task):
        return wrapper(task_or_dataset)

    if not isinstance(task_or_dataset, MetaDataset):
        raise NotImplementedError()

    # Keep any transform already registered and run the wrapper after it.
    existing = task_or_dataset.dataset_transform
    task_or_dataset.dataset_transform = (
        wrapper if existing is None else Compose([existing, wrapper]))
    return task_or_dataset
github tristandeleu / pytorch-meta / torchmeta / utils / data / task.py View on Github external
target_transform=target_transform)
        self.num_classes = num_classes


class ConcatTask(Task, ConcatDataset):
    """Task that is also a `ConcatDataset` over the given datasets.

    Parameters
    ----------
    datasets
        Forwarded unchanged to `ConcatDataset.__init__`.

    num_classes
        Forwarded to `Task.__init__`.

    target_transform : optional
        Given to the `target_transform_append` method of each entry in
        `self.datasets`; `Task.__init__` does not receive it.
    """
    def __init__(self, datasets, num_classes, target_transform=None):
        # Explicit base-class initialization: `Task`, then `ConcatDataset`.
        Task.__init__(self, num_classes)
        ConcatDataset.__init__(self, datasets)
        # Every entry of `self.datasets` is asked to append the transform.
        for task in self.datasets:
            task.target_transform_append(target_transform)

    def __getitem__(self, index):
        # Explicitly use the `ConcatDataset` implementation for indexing.
        return ConcatDataset.__getitem__(self, index)


class SubsetTask(Task, Subset):
    """Task that is also a `Subset` of another dataset.

    Parameters
    ----------
    dataset
        Forwarded to `Subset.__init__`. It must expose `num_classes` when
        `num_classes` is not given here.

    indices
        Forwarded to `Subset.__init__`.

    num_classes : optional
        Number of classes, forwarded to `Task.__init__`. Defaults to
        `dataset.num_classes` when None.

    target_transform : optional
        Given to `self.dataset.target_transform_append`; `Task.__init__`
        does not receive it.
    """
    def __init__(self, dataset, indices, num_classes=None,
                 target_transform=None):
        # When no class count is given, reuse the one stored on `dataset`.
        if num_classes is None:
            num_classes = dataset.num_classes
        # Explicit base-class initialization: `Task`, then `Subset`.
        Task.__init__(self, num_classes)
        Subset.__init__(self, dataset, indices)
        # `self.dataset` is presumably set by `Subset.__init__` (it is not
        # assigned in this class) -- TODO confirm.
        self.dataset.target_transform_append(target_transform)

    def __getitem__(self, index):
        # Explicitly use the `Subset` implementation for indexing.
        return Subset.__getitem__(self, index)