How to use the geoopt.manifolds.Manifold class in geoopt

To help you get started, we’ve selected a few geoopt examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

GitHub: geoopt / geoopt / geoopt / manifolds.py (View on GitHub, external)
return x + t * u

    def _inner(self, x, u, v):
        return u * v

    def _proju(self, x, u):
        return u

    def _projx(self, x):
        return x

    def _transp(self, x, u, v, t):
        return v


class Stiefel(Manifold):
    """Stiefel manifold; a point is a matrix stored in the last two dims of ``x``.

    Any leading dimensions are treated as batch dimensions (all matrix
    operations below act on the trailing ``(-2, -1)`` axes).
    """

    name = "Stiefel"
    ndim = 2
    reversible = True

    def check_dims(self, x):
        """Return whether ``x`` has at least the two trailing matrix dimensions."""
        return x.dim() > 1

    def amat(self, x, u, project=True):
        """Build the matrix ``u @ x^T - x @ u^T``.

        Parameters
        ----------
        x : tensor
            point on the manifold
        u : tensor
            direction; passed through ``self.proju`` first when ``project`` is true
        project : bool
            whether to project ``u`` before forming the product

        Returns
        -------
        tensor
            skew-symmetric matrix (its transpose equals its negation by construction)
        """
        direction = self.proju(x, u) if project else u
        left = direction @ x.transpose(-1, -2)
        right = x @ direction.transpose(-1, -2)
        return left - right

    def _proju(self, x, u):
        """Return ``(I - 0.5 * x @ x^T) @ u``.

        NOTE(review): this weighting is not the orthogonal projection onto the
        tangent space in general; it looks intended for the canonical-metric /
        Cayley-style construction in ``amat`` -- confirm intent.
        """
        size = x.shape[-2]
        weight = (-0.5 * x) @ x.transpose(-1, -2)
        # Add the identity on the diagonal of every (size, size) matrix in the batch.
        diag = range(size)
        weight[..., diag, diag] += 1
        return weight @ u
GitHub: geoopt / geoopt / geoopt / utils.py (View on GitHub, external)
Parameters
    ----------
    instance : geoopt.Manifold
        check if a given manifold is compatible with cls API
    cls : type
        manifold type

    Returns
    -------
    bool
        comparison result
    """
    if not issubclass(cls, geoopt.manifolds.Manifold):
        raise TypeError("`cls` should be a subclass of geoopt.manifolds.Manifold")
    if not isinstance(instance, geoopt.manifolds.Manifold):
        return False
    else:
        # this is the case to care about, Scaled class is a proxy, but fails instance checks
        while isinstance(instance, geoopt.Scaled):
            instance = instance.base
        return isinstance(instance, cls)
GitHub: geoopt / geoopt / geoopt / manifolds.py (View on GitHub, external)
    @abc.abstractmethod
    def _proju(self, x, u):
        raise NotImplementedError

    @abc.abstractmethod
    def _projx(self, x):
        raise NotImplementedError

    def __repr__(self):
        return self.name + " manifold"

    def __eq__(self, other):
        return type(self) is type(other)


class Rn(Manifold):
    """Unconstrained real space: every operation below is the plain vector-space one."""

    name = "Rn"
    ndim = 0
    reversible = True

    def check_dims(self, x):
        """Accept a tensor of any shape."""
        return True

    def _retr(self, x, u, t):
        """Move from ``x`` along ``u`` by step ``t``: ``x + t * u``."""
        step = t * u
        return x + step

    def _inner(self, x, u, v):
        """Return the elementwise product of ``u`` and ``v``; ``x`` is ignored."""
        product = u * v
        return product

    def _proju(self, x, u):
        """Return ``u`` unchanged; no constraint is applied."""
        projected = u
        return projected
GitHub: geoopt / geoopt / geoopt / tensor.py (View on GitHub, external)
    @insert_docs(Manifold.proju.__doc__, r"\s+x : .+\n.+", "")
    def proju(self, u: torch.Tensor, **kwargs) -> torch.Tensor:
        # Delegate to the attached manifold, passing this tensor itself as the point.
        # No docstring here: the decorator presumably builds it from
        # Manifold.proju's docs with the ``x`` entry removed -- confirm in insert_docs.
        target = self.manifold
        return target.proju(self, u, **kwargs)