How to use torchmeta - 10 common examples

To help you get started, we’ve selected ten torchmeta examples based on popular ways the library is used in public projects.

From tristandeleu/pytorch-meta: torchmeta/modules/conv.py (view on GitHub)

import torch.nn as nn
import torch.nn.functional as F

from collections import OrderedDict
from torch.nn.modules.utils import _single, _pair, _triple
from torchmeta.modules.module import MetaModule

class MetaConv1d(nn.Conv1d, MetaModule):
    __doc__ = nn.Conv1d.__doc__

    def forward(self, input, params=None):
        if params is None:
            params = OrderedDict(self.named_parameters())
        bias = params.get('bias', None)

        if self.padding_mode == 'circular':
            expanded_padding = ((self.padding[0] + 1) // 2, self.padding[0] // 2)
            return F.conv1d(F.pad(input, expanded_padding, mode='circular'),
                            params['weight'], bias, self.stride,
                            _single(0), self.dilation, self.groups)

        return F.conv1d(input, params['weight'], bias, self.stride,
                        self.padding, self.dilation, self.groups)

From tristandeleu/pytorch-meta: torchmeta/modules/conv.py (view on GitHub)

def forward(self, input, params=None):
        if params is None:
            params = OrderedDict(self.named_parameters())
        bias = params.get('bias', None)

        if self.padding_mode == 'circular':
            expanded_padding = ((self.padding[1] + 1) // 2, self.padding[1] // 2,
                                (self.padding[0] + 1) // 2, self.padding[0] // 2)
            return F.conv2d(F.pad(input, expanded_padding, mode='circular'),
                            params['weight'], bias, self.stride,
                            _pair(0), self.dilation, self.groups)

        return F.conv2d(input, params['weight'], bias, self.stride,
                        self.padding, self.dilation, self.groups)

class MetaConv3d(nn.Conv3d, MetaModule):
    __doc__ = nn.Conv3d.__doc__

    def forward(self, input, params=None):
        if params is None:
            params = OrderedDict(self.named_parameters())
        bias = params.get('bias', None)

        if self.padding_mode == 'circular':
            expanded_padding = ((self.padding[2] + 1) // 2, self.padding[2] // 2,
                                (self.padding[1] + 1) // 2, self.padding[1] // 2,
                                (self.padding[0] + 1) // 2, self.padding[0] // 2)
            return F.conv3d(F.pad(input, expanded_padding, mode='circular'),
                            params['weight'], bias, self.stride,
                            _triple(0), self.dilation, self.groups)

        return F.conv3d(input, params['weight'], bias, self.stride,
                        self.padding, self.dilation, self.groups)

From tristandeleu/pytorch-meta: torchmeta/modules/conv.py (view on GitHub)

def forward(self, input, params=None):
        if params is None:
            params = OrderedDict(self.named_parameters())
        bias = params.get('bias', None)

        if self.padding_mode == 'circular':
            expanded_padding = ((self.padding[0] + 1) // 2, self.padding[0] // 2)
            return F.conv1d(F.pad(input, expanded_padding, mode='circular'),
                            params['weight'], bias, self.stride,
                            _single(0), self.dilation, self.groups)

        return F.conv1d(input, params['weight'], bias, self.stride,
                        self.padding, self.dilation, self.groups)

class MetaConv2d(nn.Conv2d, MetaModule):
    __doc__ = nn.Conv2d.__doc__

    def forward(self, input, params=None):
        if params is None:
            params = OrderedDict(self.named_parameters())
        bias = params.get('bias', None)

        if self.padding_mode == 'circular':
            expanded_padding = ((self.padding[1] + 1) // 2, self.padding[1] // 2,
                                (self.padding[0] + 1) // 2, self.padding[0] // 2)
            return F.conv2d(F.pad(input, expanded_padding, mode='circular'),
                            params['weight'], bias, self.stride,
                            _pair(0), self.dilation, self.groups)

        return F.conv2d(input, params['weight'], bias, self.stride,
                        self.padding, self.dilation, self.groups)
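
Because every Meta* layer accepts `params`, a whole model built from them can be adapted functionally. The sketch below shows a single MAML-style inner step using torchmeta's gradient_update_parameters helper; the architecture, tensor shapes, and step size are assumptions chosen for illustration, not part of the snippet above.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchmeta.modules import MetaConv2d, MetaLinear, MetaSequential
from torchmeta.utils.gradient_based import gradient_update_parameters

# Toy 5-way classifier; the sizes are arbitrary.
model = MetaSequential(
    MetaConv2d(1, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Flatten(),
    MetaLinear(8 * 28 * 28, 5))

support_inputs = torch.randn(25, 1, 28, 28)   # fake 5-way 5-shot support set
support_targets = torch.randint(5, (25,))
query_inputs = torch.randn(75, 1, 28, 28)

inner_loss = F.cross_entropy(model(support_inputs), support_targets)
# One gradient step on a copy of the parameters; the module itself is left untouched.
adapted_params = gradient_update_parameters(model, inner_loss, step_size=0.4)
query_logits = model(query_inputs, params=adapted_params)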

From tristandeleu/pytorch-meta: torchmeta/toy/helpers.py (view on GitHub)

'Ignoring the argument `shots`.', stacklevel=2)
        if test_shots is not None:
            shots = kwargs['num_samples_per_task'] - test_shots
            if shots <= 0:
                raise ValueError('The argument `test_shots` ({0}) is greater '
                    'than the number of samples per task ({1}). Either use the '
                    'argument `shots` instead of `num_samples_per_task`, or '
                    'increase the value of `num_samples_per_task`.'.format(
                    test_shots, kwargs['num_samples_per_task']))
        else:
            shots = kwargs['num_samples_per_task'] // 2
    if test_shots is None:
        test_shots = shots

    dataset = Harmonic(num_samples_per_task=shots + test_shots, **kwargs)
    dataset = ClassSplitter(dataset, shuffle=shuffle,
        num_train_per_class=shots, num_test_per_class=test_shots)
    dataset.seed(seed)

    return dataset
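
The helper above builds the toy Harmonic dataset with shots + test_shots samples per task and uses ClassSplitter to split every task into a train and a test part. A minimal sketch of calling it, assuming the helper's signature is harmonic(shots, shuffle=True, test_shots=None, seed=None, **kwargs) as suggested by the body shown here:

from torchmeta.toy.helpers import harmonic

dataset = harmonic(shots=10, test_shots=5, seed=0)
task = dataset[0]                          # one regression task, already split
train_split, test_split = task['train'], task['test']
print(len(train_split), len(test_split))   # 10 and 5 samples (assumed)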

From tristandeleu/pytorch-meta: torchmeta/toy/helpers.py (view on GitHub)

'Ignoring the argument `shots`.', stacklevel=2)
        if test_shots is not None:
            shots = kwargs['num_samples_per_task'] - test_shots
            if shots <= 0:
                raise ValueError('The argument `test_shots` ({0}) is greater '
                    'than the number of samples per task ({1}). Either use the '
                    'argument `shots` instead of `num_samples_per_task`, or '
                    'increase the value of `num_samples_per_task`.'.format(
                    test_shots, kwargs['num_samples_per_task']))
        else:
            shots = kwargs['num_samples_per_task'] // 2
    if test_shots is None:
        test_shots = shots

    dataset = Sinusoid(num_samples_per_task=shots + test_shots, **kwargs)
    dataset = ClassSplitter(dataset, shuffle=shuffle,
        num_train_per_class=shots, num_test_per_class=test_shots)
    dataset.seed(seed)

    return dataset
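
The sinusoid helper is usually fed into torchmeta's BatchMetaDataLoader, which collates whole tasks into batches. A rough sketch of that loop; the batch size and the shapes in the comments are assumptions:

from torchmeta.toy.helpers import sinusoid
from torchmeta.utils.data import BatchMetaDataLoader

dataset = sinusoid(shots=10, test_shots=10, seed=0)
dataloader = BatchMetaDataLoader(dataset, batch_size=16, num_workers=2)

for batch in dataloader:
    train_inputs, train_targets = batch['train']   # roughly (16, 10, 1) each
    test_inputs, test_targets = batch['test']
    # ... run the inner/outer meta-learning updates here ...
    break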

From tristandeleu/pytorch-meta: torchmeta/datasets/tcga.py (view on GitHub)

            all_sample_ids_file = os.path.join(self.root, 'all_sample_ids.json')
            with open(all_sample_ids_file, 'w') as f:
                json.dump(all_sample_ids, f)

            if os.path.isfile(csv_file):
                os.remove(csv_file)

            print('Done')

        self._process_clinical_matrices()

        # Create label files
        for split in ['train', 'val', 'test']:
            filename = os.path.join(self.root, self.filename_tasks.format(split))
            data = get_asset(self.folder, '{0}.json'.format(split), dtype='json')

            with open(filename, 'w') as f:
                labels = sorted([key.split('|', 1) for key in data])
                json.dump(labels, f)

        # Clean up
        for cancer in self.cancers:
            filename = self.clinical_matrix_filename.format(cancer)
            rawpath = os.path.join(clinical_matrices_folder, '{0}.gz'.format(filename))
            if os.path.isfile(rawpath):
                os.remove(rawpath)
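
All of this runs as part of the dataset's download step. A hedged sketch of triggering it; the constructor arguments follow the meta_train/download convention used elsewhere in torchmeta and are assumptions here, and downloading the gene-expression matrix may require extra dependencies.

from torchmeta.datasets import TCGA

# Assumed arguments: processes the clinical matrices and writes the per-split JSON task files.
dataset = TCGA('data', meta_train=True, download=True)
print(len(dataset))   # number of meta-training tasks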

From tristandeleu/pytorch-meta: torchmeta/datasets/cub.py (view on GitHub)

return

        filename = os.path.basename(self.download_url)
        download_url(self.download_url, self.root, filename, self.tgz_md5)

        tgz_filename = os.path.join(self.root, filename)
        with tarfile.open(tgz_filename, 'r') as f:
            f.extractall(self.root)
        image_folder = os.path.join(self.root, self.image_folder)

        for split in ['train', 'val', 'test']:
            filename = os.path.join(self.root, self.filename.format(split))
            if os.path.isfile(filename):
                continue

            labels = get_asset(self.folder, '{0}.json'.format(split))
            labels_filename = os.path.join(self.root, self.filename_labels.format(split))
            with open(labels_filename, 'w') as f:
                json.dump(labels, f)

            with h5py.File(filename, 'w') as f:
                group = f.create_group('datasets')
                dtype = h5py.special_dtype(vlen=np.uint8)
                for i, label in enumerate(tqdm(labels, desc=filename)):
                    images = glob.glob(os.path.join(image_folder, label, '*.jpg'))
                    images.sort()
                    dataset = group.create_dataset(label, (len(images),), dtype=dtype)
                    for i, image in enumerate(images):
                        with open(image, 'rb') as f:
                            array = bytearray(f.read())
                            dataset[i] = np.asarray(array, dtype=np.uint8)
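
Once the HDF5 and label files are built, the dataset is normally consumed through the high-level helper rather than through this download code directly. A sketch, assuming the cub helper mirrors the other torchmeta.datasets.helpers functions (folder, ways, shots, test_shots, meta_train, download):

from torchmeta.datasets.helpers import cub
from torchmeta.utils.data import BatchMetaDataLoader

dataset = cub('data', ways=5, shots=1, test_shots=15, meta_train=True, download=True)
dataloader = BatchMetaDataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)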

From tristandeleu/pytorch-meta: torchmeta/datasets/tcga.py (view on GitHub)

def get_task_id_splits(meta_split):
    return get_asset(TCGA.folder, '{}.json'.format(meta_split), dtype='json')

From tristandeleu/pytorch-meta: torchmeta/datasets/omniglot.py (view on GitHub)

                for _, alphabet, character in characters:
                    filenames = glob.glob(os.path.join(self.root, name,
                        alphabet, character, '*.png'))
                    dataset = group.create_dataset('{0}/{1}'.format(alphabet,
                        character), (len(filenames), 105, 105), dtype='uint8')

                    for i, char_filename in enumerate(filenames):
                        image = Image.open(char_filename, mode='r').convert('L')
                        dataset[i] = ImageOps.invert(image)

                shutil.rmtree(os.path.join(self.root, name))

        for split in ['train', 'val', 'test']:
            filename = os.path.join(self.root, self.filename_labels.format(
                'vinyals_', split))
            data = get_asset(self.folder, '{0}.json'.format(split), dtype='json')

            with open(filename, 'w') as f:
                labels = sorted([('images_{0}'.format(name), alphabet, character)
                    for (name, alphabets) in data.items()
                    for (alphabet, characters) in alphabets.items()
                    for character in characters])
                json.dump(labels, f)
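
The label files written here are what the Omniglot meta-dataset later indexes. The usual way to consume it is through the omniglot helper plus BatchMetaDataLoader, as in torchmeta's README; the shapes in the comment assume the helper's default 28x28 resizing.

from torchmeta.datasets.helpers import omniglot
from torchmeta.utils.data import BatchMetaDataLoader

dataset = omniglot('data', ways=5, shots=5, test_shots=15, meta_train=True, download=True)
dataloader = BatchMetaDataLoader(dataset, batch_size=16, num_workers=4)

for batch in dataloader:
    train_inputs, train_targets = batch['train']   # about (16, 25, 1, 28, 28) and (16, 25)
    break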

From tristandeleu/pytorch-meta: torchmeta/datasets/tcga.py (view on GitHub)

def get_cancers():
    return get_asset(TCGA.folder, 'cancers.json', dtype='json')