How to use the torchmeta.transforms.Categorical function in torchmeta

To help you get started, we’ve selected a few torchmeta examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github tristandeleu / pytorch-meta / torchmeta / utils / data / dataset.py View on Github external
def __getitem__(self, index):
        """Build the task made of the classes selected by `index`.

        Parameters
        ----------
        index : tuple of int
            One class index per class in the task; its length must equal
            `self.num_classes_per_task`.

        Returns
        -------
        task
            A `ConcatTask` over the selected class datasets, passed through
            `self.dataset_transform` when that attribute is not None.

        Raises
        ------
        ValueError
            If `index` is a single integer rather than a tuple of integers.
        AssertionError
            If `index` does not hold exactly `self.num_classes_per_task`
            entries.
        """
        if isinstance(index, int):
            ways = self.num_classes_per_task
            example = ', '.join(map(str, range(ways)))
            raise ValueError('The index of a `CombinationMetaDataset` must be '
                'a tuple of integers, and not an integer. For example, call '
                '`dataset[({0})]` to get a task with classes from 0 to {1} '
                '(got `{2}`).'.format(example, ways - 1, index))
        assert len(index) == self.num_classes_per_task

        # One per-class dataset for every class index in the task.
        class_datasets = []
        for class_index in index:
            class_datasets.append(self.dataset[class_index])

        # `Categorical` target transforms are routed through
        # `_copy_categorical` (which deep-copies them), to avoid any side
        # effect across tasks.
        target_transform = wrap_transform(self.target_transform,
            self._copy_categorical, transform_type=Categorical)
        task = ConcatTask(class_datasets, self.num_classes_per_task,
            target_transform=target_transform)

        if self.dataset_transform is None:
            return task
        return self.dataset_transform(task)
github tristandeleu / pytorch-meta / torchmeta / datasets / helpers.py View on Github external
kwargs
        Additional arguments passed to the `CIFARFS` class.

    See also
    --------
    `datasets.cifar100.CIFARFS` : Meta-dataset for the CIFAR-FS dataset.
    """
    if 'num_classes_per_task' in kwargs:
        warnings.warn('Both arguments `ways` and `num_classes_per_task` were '
            'set in the helper function for the number of classes per task. '
            'Ignoring the argument `ways`.', stacklevel=2)
        ways = kwargs['num_classes_per_task']
    if 'transform' not in kwargs:
        kwargs['transform'] = ToTensor()
    if 'target_transform' not in kwargs:
        kwargs['target_transform'] = Categorical(ways)
    if test_shots is None:
        test_shots = shots

    dataset = CIFARFS(folder, num_classes_per_task=ways, **kwargs)
    dataset = ClassSplitter(dataset, shuffle=shuffle,
        num_train_per_class=shots, num_test_per_class=test_shots)
    dataset.seed(seed)

    return dataset
github tristandeleu / pytorch-meta / torchmeta / datasets / helpers.py View on Github external
--------
    `datasets.cub.CUB` : Meta-dataset for the Caltech-UCSD Birds dataset.
    """
    # An explicit `num_classes_per_task` keyword takes precedence over `ways`.
    # This must be a membership test (`in`): an identity test (`is`) between
    # the key string and the `kwargs` dict can never be true, which would
    # silently skip both the warning and the override.
    if 'num_classes_per_task' in kwargs:
        warnings.warn('Both arguments `ways` and `num_classes_per_task` were '
            'set in the helper function for the number of classes per task. '
            'Ignoring the argument `ways`.', stacklevel=2)
        ways = kwargs['num_classes_per_task']
    if 'transform' not in kwargs:
        image_size = 84
        kwargs['transform'] = Compose([
            Resize(int(image_size * 1.5)),
            CenterCrop(image_size),
            ToTensor()])
    if 'target_transform' not in kwargs:
        kwargs['target_transform'] = Categorical(ways)
    if test_shots is None:
        test_shots = shots

    dataset = CUB(folder, num_classes_per_task=ways, **kwargs)
    dataset = ClassSplitter(dataset, shuffle=shuffle,
        num_train_per_class=shots, num_test_per_class=test_shots)
    dataset.seed(seed)

    return dataset
github tristandeleu / pytorch-meta / torchmeta / datasets / helpers.py View on Github external
kwargs
        Additional arguments passed to the `TieredImagenet` class.

    See also
    --------
    `datasets.TieredImagenet` : Meta-dataset for the Tiered-Imagenet dataset.
    """
    if 'num_classes_per_task' in kwargs:
        warnings.warn('Both arguments `ways` and `num_classes_per_task` were '
            'set in the helper function for the number of classes per task. '
            'Ignoring the argument `ways`.', stacklevel=2)
        ways = kwargs['num_classes_per_task']
    if 'transform' not in kwargs:
        kwargs['transform'] = Compose([Resize(84), ToTensor()])
    if 'target_transform' not in kwargs:
        kwargs['target_transform'] = Categorical(ways)
    if test_shots is None:
        test_shots = shots

    dataset = TieredImagenet(folder, num_classes_per_task=ways, **kwargs)
    dataset = ClassSplitter(dataset, shuffle=shuffle,
        num_train_per_class=shots, num_test_per_class=test_shots)
    dataset.seed(seed)

    return dataset
github tristandeleu / pytorch-meta / torchmeta / datasets / helpers.py View on Github external
kwargs
        Additional arguments passed to the `MiniImagenet` class.

    See also
    --------
    `datasets.MiniImagenet` : Meta-dataset for the Mini-Imagenet dataset.
    """
    if 'num_classes_per_task' in kwargs:
        warnings.warn('Both arguments `ways` and `num_classes_per_task` were '
            'set in the helper function for the number of classes per task. '
            'Ignoring the argument `ways`.', stacklevel=2)
        ways = kwargs['num_classes_per_task']
    if 'transform' not in kwargs:
        kwargs['transform'] = Compose([Resize(84), ToTensor()])
    if 'target_transform' not in kwargs:
        kwargs['target_transform'] = Categorical(ways)
    if test_shots is None:
        test_shots = shots

    dataset = MiniImagenet(folder, num_classes_per_task=ways, **kwargs)
    dataset = ClassSplitter(dataset, shuffle=shuffle,
        num_train_per_class=shots, num_test_per_class=test_shots)
    dataset.seed(seed)

    return dataset
github tristandeleu / pytorch-meta / torchmeta / utils / data / dataset.py View on Github external
def _copy_categorical(self, transform):
        """Reset a `Categorical` transform and return a deep copy of it.

        The transform passed in is itself modified: `reset()` is called on
        it, and a `num_classes` of None is replaced by
        `self.num_classes_per_task`. The returned object is a separate deep
        copy taken after those changes.

        Parameters
        ----------
        transform : Categorical
            The target transform to reset and copy.

        Returns
        -------
        Categorical
            A deep copy of `transform`.
        """
        assert isinstance(transform, Categorical)
        transform.reset()
        missing_num_classes = transform.num_classes is None
        if missing_num_classes:
            # Fall back to the number of classes per task of this dataset.
            transform.num_classes = self.num_classes_per_task
        clone = deepcopy(transform)
        return clone
github tristandeleu / pytorch-meta / torchmeta / datasets / helpers.py View on Github external
kwargs
        Additional arguments passed to the `Omniglot` class.

    See also
    --------
    `datasets.Omniglot` : Meta-dataset for the Omniglot dataset.
    """
    if 'num_classes_per_task' in kwargs:
        warnings.warn('Both arguments `ways` and `num_classes_per_task` were '
            'set in the helper function for the number of classes per task. '
            'Ignoring the argument `ways`.', stacklevel=2)
        ways = kwargs['num_classes_per_task']
    if 'transform' not in kwargs:
        kwargs['transform'] = Compose([Resize(28), ToTensor()])
    if 'target_transform' not in kwargs:
        kwargs['target_transform'] = Categorical(ways)
    if 'class_augmentations' not in kwargs:
        kwargs['class_augmentations'] = [Rotation([90, 180, 270])]
    if test_shots is None:
        test_shots = shots

    dataset = Omniglot(folder, num_classes_per_task=ways, **kwargs)
    dataset = ClassSplitter(dataset, shuffle=shuffle,
        num_train_per_class=shots, num_test_per_class=test_shots)
    dataset.seed(seed)

    return dataset