Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.
def index_update(self, indices, values):
if isinstance(indices, tuple):
indices = tuple(
t.tensor if isinstance(t, self.__class__) else t for t in indices
)
x = self.tensor
if isinstance(indices, int):
return self.backend.tensor_scatter_nd_update(x, [[indices]], values[None])
elif isinstance(indices, tuple) and any(
isinstance(idx, slice) for idx in indices
):
if (
len(indices) == x.ndim == 2
and indices[0] == index[:]
and not isinstance(indices[1], slice)
):
x = self.backend.transpose(x)
result = self.backend.tensor_scatter_nd_update(
x, [[indices[-1]]], values[None]
)
return self.backend.transpose(result)
else:
raise NotImplementedError
elif isinstance(indices, tuple):
if all(
idx.dtype in [self.backend.int32, self.backend.int64] for idx in indices
):
indices = [
self.backend.cast(idx, self.backend.int64)
if idx.dtype == self.backend.int32