Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
from collections.abc import Iterable
def samedevice(f):
    """Decorate a method so its body runs under the wrapped tensor's device scope.

    The returned wrapper enters ``tf.device(self.tensor.device)`` before
    calling *f*, so any TensorFlow ops created inside *f* are placed on the
    same device as ``self.tensor``.

    Args:
        f: A method whose ``self`` exposes a ``tensor`` attribute with a
            ``device`` attribute.

    Returns:
        The wrapping function, with *f*'s metadata copied over by
        ``functools.wraps``.
    """
    # functools is not imported anywhere visible for this module; import it
    # locally so the decorator does not fail with NameError when applied.
    import functools

    import tensorflow as tf

    @functools.wraps(f)
    def wrapper(self, *args, **kwargs):
        with tf.device(self.tensor.device):
            return f(self, *args, **kwargs)

    return wrapper
class TensorFlowTensor(AbstractTensor):
    """AbstractTensor implementation whose backend module is TensorFlow."""

    def __init__(self, tensor):
        # TensorFlow is imported inside the constructor, so it is only loaded
        # once an instance of this wrapper is actually created.
        import tensorflow as tf_backend

        super().__init__(tensor)
        self.backend = tf_backend
@unwrapin
@wrapout
def __getitem__(self, index):
if isinstance(index, tuple):
index = tuple(
x.tensor if isinstance(x, self.__class__) else x for x in index
)
tensors = any(
isinstance(x, self.backend.Tensor) or isinstance(x, np.ndarray)
for x in index
from .base import AbstractTensor
from .base import wrapout
from .base import istensor
from .base import unwrapin
import numpy as np
from collections.abc import Iterable
class PyTorchTensor(AbstractTensor):
    """AbstractTensor implementation whose backend module is ``torch``."""

    def __init__(self, tensor):
        # torch is imported inside the constructor, so it is only loaded once
        # an instance of this wrapper is created.
        import torch

        super().__init__(tensor)
        self.backend = torch

    def numpy(self):
        """Return the wrapped tensor's data as a NumPy array.

        ``detach()`` is required: PyTorch refuses to call ``.numpy()`` on a
        tensor that requires grad and raises RuntimeError. ``cpu()`` moves a
        device tensor to host memory, which ``.numpy()`` also requires.
        """
        return self.tensor.detach().cpu().numpy()

    def item(self):
        """Return the result of the wrapped tensor's ``item()``."""
        return self.tensor.item()

    @property
    def shape(self):
        """The wrapped tensor's ``shape``."""
        return self.tensor.shape
from .base import AbstractTensor
from .base import unwrapin
from .base import wrapout
class NumPyTensor(AbstractTensor):
    """AbstractTensor implementation whose backend module is NumPy."""

    def __init__(self, tensor):
        # The NumPy module is imported here and stored as this wrapper's backend.
        import numpy as np_backend

        super().__init__(tensor)
        self.backend = np_backend

    def numpy(self):
        """Return the wrapped array object itself; no conversion or copy is done."""
        wrapped = self.tensor
        return wrapped

    def item(self):
        """Return the result of the wrapped array's ``item()``."""
        wrapped = self.tensor
        return wrapped.item()

    @property
    def shape(self):
        """The wrapped array's ``shape``."""
        wrapped = self.tensor
        return wrapped.shape
from .base import AbstractTensor
from .base import unwrapin
from .base import wrapout
import numpy as onp
from collections.abc import Iterable
class JAXTensor(AbstractTensor):
    """AbstractTensor implementation backed by JAX, using ``jax.numpy`` as backend."""

    # Class-level attribute, None by default; how it is used is not visible
    # here (presumably a shared PRNG key — TODO confirm).
    key = None

    def __init__(self, tensor):
        # JAX is imported inside the constructor; both the top-level module
        # and its numpy namespace are kept on the instance.
        import jax as jax_module
        import jax.numpy as jnp

        super().__init__(tensor)
        self.jax = jax_module
        self.backend = jnp

    def numpy(self):
        """Convert the wrapped array with host NumPy's ``asarray``."""
        host_array = onp.asarray(self.tensor)
        return host_array

    def item(self):
        """Return the result of the wrapped array's ``item()``."""
        wrapped = self.tensor
        return wrapped.item()