Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
@unwrapin
@wrapout
def onehot_like(self, indices, *, value=1):
    """Return a tensor like ``self`` whose row ``i`` is one-hot at ``indices[i]``.

    The result starts as ``zeros_like(self.tensor)`` (so it keeps the wrapped
    tensor's shape and dtype), and ``value`` is written at column
    ``indices[i]`` of each row ``i``.

    Args:
        indices: 1-D tensor with one column index per row of ``self``.
        value: Value stored at the selected positions (default ``1``).
    """
    assert self.tensor.ndim == 2
    assert indices.ndim == 1
    # Without this check, fancy indexing would broadcast a length-1 ``indices``
    # against ``rows`` and mark the same column in every row. The check keeps
    # this variant consistent with the ``one_hot``-based implementation.
    assert len(indices) == len(self.tensor)
    x = self.backend.zeros_like(self.tensor)
    rows = np.arange(len(x))
    # Pair row i with column indices[i] via integer-array indexing.
    x[rows, indices] = value
    return x
@unwrapin
@wrapout
def onehot_like(self, indices, *, value=1):
    """Return one-hot rows with ``value`` at column ``indices[i]`` of row ``i``.

    Builds the one-hot mask by comparing a ``(1, C)`` column-index row against
    ``indices`` reshaped to ``(N, 1)``, then scales the mask by ``value``.
    The previous version returned the bare boolean mask and ignored ``value``.
    The result dtype follows the backend's promotion of ``bool * value``.

    Args:
        indices: 1-D tensor with one column index per row of ``self``.
        value: Value stored at the selected positions (default ``1``).
    """
    assert self.tensor.ndim == 2
    assert indices.ndim == 1
    # Without this check, broadcasting would silently yield len(indices) rows
    # instead of one row per row of ``self``.
    assert len(indices) == len(self.tensor)
    x = self.backend.arange(self.tensor.shape[1]).reshape(1, -1)
    indices = indices.reshape(-1, 1)
    # Broadcasting (1, C) == (N, 1) produces the (N, C) one-hot mask.
    return (x == indices) * value
@unwrapin
@wrapout
def tile(self, multiples):
    """Repeat the wrapped tensor ``multiples[k]`` times along each axis ``k``.

    ``multiples`` must supply exactly one repetition count per dimension.
    """
    requested_rank = len(multiples)
    assert requested_rank == self.ndim
    tiled = self.backend.tile(self.tensor, multiples)
    return tiled
@unwrapin
@wrapout
def index_update(self, indices, values):
    """Return a copy of the wrapped array with ``values`` written at ``indices``.

    Arrays of this type are updated functionally: the original is not
    modified. Any wrapper instances inside a tuple index are unwrapped to
    their raw tensors first.

    Args:
        indices: An index expression (int, slice, array, or tuple of these).
        values: Values to place at the indexed positions.
    """
    if isinstance(indices, tuple):
        indices = tuple(
            t.tensor if isinstance(t, self.__class__) else t for t in indices
        )
    # ``jax.ops.index_update`` was deprecated and then removed from JAX.
    # The ``.at[...].set(...)`` property is its documented functional replacement.
    return self.tensor.at[indices].set(values)
# NOTE(review): this copy of index_update is truncated. The `if (` opened
# near the end is never closed, and `index` (in `index[:]`) is not defined
# anywhere in this block, so the definition cannot run as-is. Restore the
# remainder from the original source before use.
@unwrapin
@wrapout
def index_update(self, indices, values):
# Unwrap any wrapper-class elements of a tuple index to their raw tensors.
if isinstance(indices, tuple):
indices = tuple(
t.tensor if isinstance(t, self.__class__) else t for t in indices
)
x = self.tensor
# Scalar int index: scatter `values` (given a leading length-1 axis) at
# index [[indices]] via the backend's tensor_scatter_nd_update. This
# presumably replaces position `indices` along the first axis — TODO confirm.
if isinstance(indices, int):
return self.backend.tensor_scatter_nd_update(x, [[indices]], values[None])
# Tuple index that contains at least one slice.
elif isinstance(indices, tuple) and any(
isinstance(idx, slice) for idx in indices
):
if (
len(indices) == x.ndim == 2
and indices[0] == index[:]
@unwrapin
@wrapout
def logical_and(self, other):
    """Element-wise AND of this boolean tensor with ``other`` using ``&``.

    The wrapped tensor's dtype must equal ``self.backend.bool``.
    """
    assert self.dtype == self.backend.bool
    lhs = self.tensor
    return lhs & other
@unwrapin
@wrapout
def maximum(self, other):
    """Element-wise maximum of the wrapped tensor and ``other``."""
    elementwise_max = self.backend.maximum
    return elementwise_max(self.tensor, other)
@unwrapin
@wrapout
def onehot_like(self, indices, *, value=1):
    """Return one-hot rows shaped and typed like the wrapped 2-D tensor.

    Row ``i`` holds ``value`` (cast to the tensor's dtype) at column
    ``indices[i]``. The depth is the tensor's last-axis size. The off value
    is left to the default of ``self.backend.one_hot``.

    Args:
        indices: 1-D tensor with exactly one entry per row.
        value: Value stored at the selected positions (default ``1``).
    """
    src = self.tensor
    assert src.ndim == 2
    assert indices.ndim == 1
    assert len(indices) == len(src)
    out_dtype = src.dtype
    # Cast so the on-value matches the requested output dtype.
    on_value = self.backend.cast(value, out_dtype)
    num_columns = src.shape[-1]
    return self.backend.one_hot(
        indices,
        depth=num_columns,
        on_value=on_value,
        dtype=out_dtype,
    )
@unwrapin
@wrapout
def logical_or(self, other):
    """Element-wise logical OR; the wrapped tensor must have ``backend.bool_`` dtype."""
    assert self.dtype == self.backend.bool_
    backend = self.backend
    return backend.logical_or(self.tensor, other)
@unwrapin
@wrapout
def logical_and(self, other):
    """Element-wise logical AND via ``self.backend.logical_and``.

    The wrapped tensor's dtype must equal ``self.backend.bool``.
    """
    assert self.dtype == self.backend.bool
    combine = self.backend.logical_and
    return combine(self.tensor, other)