How to use the eagerpy.index function in eagerpy

To help you get started, we've selected a few eagerpy examples based on popular ways the function is used in public projects.


Example from jonasrauber/eagerpy on GitHub, file eagerpy/tensor/tensorflow.py:
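def index_update(self, indices, values):
    # Excerpt of a method on eagerpy's TensorFlow tensor wrapper.
    # `self.backend` is the tensorflow module; `index` is eagerpy.index.
    # Unwrap any EagerPy tensors inside a tuple index into raw tensors.
    if isinstance(indices, tuple):
        indices = tuple(
            t.tensor if isinstance(t, self.__class__) else t for t in indices
        )

    x = self.tensor
    if isinstance(indices, int):
        # Single integer: replace row `indices`; values[None] adds the
        # leading axis of length 1 that matches the [[indices]] index.
        return self.backend.tensor_scatter_nd_update(x, [[indices]], values[None])
    elif isinstance(indices, tuple) and any(
        isinstance(idx, slice) for idx in indices
    ):
        # Only x[:, j] on a 2-D tensor is supported: index[:] evaluates to
        # slice(None, None, None). Transpose so column j becomes row j,
        # scatter, then transpose back.
        if (
            len(indices) == x.ndim == 2
            and indices[0] == index[:]
            and not isinstance(indices[1], slice)
        ):
            x = self.backend.transpose(x)
            result = self.backend.tensor_scatter_nd_update(
                x, [[indices[-1]]], values[None]
            )
            return self.backend.transpose(result)
        else:
            raise NotImplementedError
    elif isinstance(indices, tuple):
        # Tuple of integer index tensors: cast int32 ones to int64 so all
        # indices share a single dtype.
        if all(
            idx.dtype in [self.backend.int32, self.backend.int64] for idx in indices
        ):
            indices = [
                self.backend.cast(idx, self.backend.int64)
                if idx.dtype == self.backend.int32
                else idx
                for idx in indices
            ]
        # ... (the remainder of the method is truncated in this excerpt)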
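The only place `eagerpy.index` appears in the excerpt is the check `indices[0] == index[:]`. As far as we can tell, `eagerpy.index` is simply an object whose `__getitem__` returns whatever key Python hands it, so `index[:]` is `slice(None, None, None)` and `index[:, 1]` is the tuple `(slice(None, None, None), 1)`. The sketch below re-implements that idea; `Index` and `is_full_slice_then_int` are hypothetical names, not part of eagerpy's API, and the helper just mirrors the column-update condition from `index_update`.

```python
from typing import Any


class Index:
    """Hypothetical stand-in for eagerpy.index: turns subscript syntax into values."""

    def __getitem__(self, key: Any) -> Any:
        # Python passes slices and tuples to __getitem__ as-is; return them unchanged.
        return key


index = Index()

full = index[:]        # slice(None, None, None)
column = index[:, 1]   # (slice(None, None, None), 1)


def is_full_slice_then_int(indices: Any, ndim: int) -> bool:
    # Mirrors the x[:, j] test in index_update (plus an explicit tuple check).
    return (
        isinstance(indices, tuple)
        and len(indices) == ndim == 2
        and indices[0] == index[:]
        and not isinstance(indices[1], slice)
    )


print(is_full_slice_then_int(column, 2))       # True
print(is_full_slice_then_int(index[1, :], 2))  # False
```

The design point is that `index[...]` lets callers write ordinary slice syntax and pass the result around as a plain Python value, which `index_update` can then compare against.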
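Both supported update paths in the excerpt rely on `tf.tensor_scatter_nd_update`. With indices `[[i]]` on a 2-D tensor, each index selects a whole row, and `values[None]` supplies the replacement row with a leading axis of length 1. The `x[:, j]` case is handled by transposing, scattering row `j`, and transposing back. The following pure-Python sketch models both paths on nested lists so it runs without TensorFlow. `scatter_rows`, `update_row` and `update_column` are hypothetical helpers, and the model covers only 2-D inputs with row indices, not the full `tensor_scatter_nd_update` semantics.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def scatter_rows(
    x: Matrix, row_indices: Sequence[Sequence[int]], updates: Sequence[Sequence[float]]
) -> Matrix:
    """Model tensor_scatter_nd_update for a 2-D x with indices of shape [N, 1].

    Each [i] in row_indices selects row i; the matching update replaces it.
    Like the TensorFlow op, this returns a new matrix and leaves x untouched.
    """
    result = [list(row) for row in x]
    for (i,), new_row in zip(row_indices, updates):
        if len(new_row) != len(result[i]):
            raise ValueError("update length must match the row being replaced")
        result[i] = list(new_row)
    return result


def transpose(x: Matrix) -> Matrix:
    return [list(col) for col in zip(*x)]


def update_row(x: Matrix, i: int, values: Sequence[float]) -> Matrix:
    # Analogue of tensor_scatter_nd_update(x, [[i]], values[None]).
    return scatter_rows(x, [[i]], [values])


def update_column(x: Matrix, j: int, values: Sequence[float]) -> Matrix:
    # Analogue of the x[:, j] branch: transpose, scatter row j, transpose back.
    return transpose(scatter_rows(transpose(x), [[j]], [values]))


x = [[1.0, 2.0, 3.0],
     [4.0, 5.0, 6.0]]
print(update_row(x, 1, [0.0, 0.0, 0.0]))  # [[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]]
print(update_column(x, 2, [9.0, 9.0]))    # [[1.0, 2.0, 9.0], [4.0, 5.0, 9.0]]
```

Transposing twice costs extra copies, but it lets a single row-oriented scatter serve both the `x[i]` and `x[:, j]` cases, which is the trade-off the excerpt makes.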

eagerpy

EagerPy is a thin wrapper around PyTorch, TensorFlow Eager, JAX and NumPy that unifies their interface and thus allows writing code that works natively across all of them.

License: MIT
Latest version published 3 years ago

Package Health Score: 46 / 100