How to use the eagerpy.tensor.base.AbstractTensor class in eagerpy

To help you get started, we’ve selected a few eagerpy examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

github jonasrauber / eagerpy / eagerpy / tensor / tensorflow.py View on Github external
from collections.abc import Iterable


def samedevice(f):
    """Decorate a tensor method so it runs under the wrapped tensor's device.

    The returned wrapper enters a ``tf.device(self.tensor.device)`` context
    before calling ``f``, so TensorFlow ops created by ``f`` are placed on the
    same device as the tensor the method is invoked on. ``tensorflow`` is
    imported when the decorator is applied, not when the module is loaded.
    """
    import tensorflow as tf

    @functools.wraps(f)
    def on_same_device(self, *args, **kwargs):
        device_scope = tf.device(self.tensor.device)
        # Returning from inside the ``with`` still exits the device scope
        # before the value reaches the caller.
        with device_scope:
            return f(self, *args, **kwargs)

    return on_same_device


class TensorFlowTensor(AbstractTensor):
    def __init__(self, tensor):
        """Wrap ``tensor`` and expose the ``tensorflow`` module as ``self.backend``.

        ``tensorflow`` is imported lazily here, so it is only required once a
        ``TensorFlowTensor`` is actually constructed.
        """
        import tensorflow as tf_module

        # Let the base class store the wrapped tensor before recording the backend.
        super().__init__(tensor)
        self.backend = tf_module

    @unwrapin
    @wrapout
    def __getitem__(self, index):
        if isinstance(index, tuple):
            index = tuple(
                x.tensor if isinstance(x, self.__class__) else x for x in index
            )
            tensors = any(
                isinstance(x, self.backend.Tensor) or isinstance(x, np.ndarray)
                for x in index
github jonasrauber / eagerpy / eagerpy / tensor / pytorch.py View on Github external
from .base import AbstractTensor
from .base import wrapout
from .base import istensor
from .base import unwrapin

import numpy as np
from collections.abc import Iterable


class PyTorchTensor(AbstractTensor):
    """EagerPy wrapper around a ``torch.Tensor``.

    The wrapped object is read back as ``self.tensor``. It is presumably stored
    by ``AbstractTensor.__init__``; confirm against ``.base``. ``self.backend``
    holds the ``torch`` module.
    """

    def __init__(self, tensor):
        # Imported lazily so torch is only needed when a PyTorchTensor is built.
        import torch

        super().__init__(tensor)
        self.backend = torch

    def numpy(self):
        """Return the tensor's data as a NumPy array in host memory.

        The tensor is detached from the autograd graph first, because
        ``torch.Tensor.numpy()`` raises ``RuntimeError`` on tensors with
        ``requires_grad=True``. ``.cpu()`` is needed because NumPy cannot
        reference device (e.g. CUDA) memory. For a tensor already on the CPU,
        the result may share memory with the tensor, as with ``Tensor.numpy()``.
        """
        return self.tensor.detach().cpu().numpy()

    def item(self):
        """Return the value of a one-element tensor as a Python number."""
        return self.tensor.item()

    @property
    def shape(self):
        """The shape of the wrapped tensor, as returned by ``tensor.shape``."""
        return self.tensor.shape
github jonasrauber / eagerpy / eagerpy / tensor / numpy.py View on Github external
from .base import AbstractTensor
from .base import unwrapin
from .base import wrapout


class NumPyTensor(AbstractTensor):
    """EagerPy wrapper around a NumPy array.

    ``self.backend`` is the ``numpy`` module itself. The wrapped array is read
    back as ``self.tensor``.
    """

    def __init__(self, tensor):
        import numpy as np_module

        super().__init__(tensor)
        self.backend = np_module

    def numpy(self):
        """Return the wrapped array object itself; no conversion or copy is done."""
        wrapped = self.tensor
        return wrapped

    def item(self):
        """Return the wrapped array's ``item()``, i.e. its single element as a scalar."""
        wrapped = self.tensor
        return wrapped.item()

    @property
    def shape(self):
        """Shape of the wrapped array, taken from its ``shape`` attribute."""
        wrapped = self.tensor
        return wrapped.shape
github jonasrauber / eagerpy / eagerpy / tensor / jax.py View on Github external
from .base import AbstractTensor
from .base import unwrapin
from .base import wrapout

import numpy as onp
from collections.abc import Iterable


class JAXTensor(AbstractTensor):
    """EagerPy wrapper around a JAX array.

    Instances keep the ``jax`` package as ``self.jax`` and ``jax.numpy`` as
    ``self.backend``.
    """

    # Class-level attribute defaulting to None. Its use is not visible here.
    # NOTE(review): presumably a shared PRNG key; confirm.
    key = None

    def __init__(self, tensor):
        # JAX is imported lazily, so it is only required once a JAXTensor is built.
        import jax
        import jax.numpy as jnp

        super().__init__(tensor)
        self.jax = jax
        self.backend = jnp

    def numpy(self):
        """Convert the wrapped array to a regular NumPy array via ``numpy.asarray``."""
        wrapped = self.tensor
        return onp.asarray(wrapped)

    def item(self):
        """Return the wrapped array's ``item()``, i.e. its single element as a scalar."""
        wrapped = self.tensor
        return wrapped.item()

eagerpy

EagerPy is a thin wrapper around PyTorch, TensorFlow Eager, JAX and NumPy that unifies their interface and thus allows writing code that works natively across all of them.

MIT
Latest version published 3 years ago

Package Health Score

46 / 100
Full package analysis