How to use the geoopt.linalg module in geoopt

To help you get started, we’ve selected a few geoopt examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github geoopt / geoopt / tests / test_manifold_basic.py View on Github external
def sphere_compliment_case():
    """Yield unary test cases for spheres configured with a complement.

    The torch seed is fixed to 42 before sampling, then ``complement``,
    ``ex`` and ``ev`` are drawn in that order.  The same data is used for
    two cases: one with ``geoopt.Sphere`` and one with
    ``geoopt.SphereExact``.

    Yields
    ------
    UnaryCase
        ``UnaryCase(shape, x, ex, v, ev, manifold)`` for each manifold
    """
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.Sphere]
    complement = torch.rand(shape[-1], 1, dtype=torch.float64)

    # P = I - Q Q^T, with Q taken from the QR factorisation of the complement
    # (presumably the projector that removes the complement direction).
    Q, _ = geoopt.linalg.batch_linalg.qr(complement)
    P = -Q @ Q.transpose(-1, -2)
    P.diagonal(dim1=-2, dim2=-1).add_(1)

    ex = torch.randn(*shape, dtype=torch.float64)
    ev = torch.randn(*shape, dtype=torch.float64)
    projected = ex @ P.t()
    x = projected / torch.norm(projected)
    v = (ev - (x @ ev) * x) @ P.t()

    for manifold_cls in (geoopt.Sphere, geoopt.SphereExact):
        manifold = manifold_cls(complement=complement)
        # x is re-wrapped on every pass, so the second case wraps the
        # ManifoldTensor that was created for the first one.
        x = geoopt.ManifoldTensor(x, manifold=manifold)
        yield UnaryCase(shape, x, ex, v, ev, manifold)
github geoopt / geoopt / tests / test_utils.py View on Github external
def test_qr(A):
    """Compare ``geoopt.linalg.qr`` on ``A`` with ``torch.qr`` item by item.

    Parameters
    ----------
    A
        iterable input (presumably a batch of matrices supplied by a
        fixture); each item is factorised separately with ``torch.qr`` and
        compared with the matching slice of the batched result
    """
    batch_q, batch_r = geoopt.linalg.qr(A)
    with torch.no_grad():
        for idx, matrix in enumerate(A):
            reference = torch.qr(matrix)
            # Check Q first, then R, against the per-item reference.
            for batched, expected in zip((batch_q, batch_r), reference):
                np.testing.assert_allclose(batched.detach()[idx], expected.detach())
github geoopt / geoopt / tests / test_utils.py View on Github external
def test_svd(A):
    """Compare ``geoopt.linalg.svd`` on ``A`` with ``torch.svd`` item by item.

    After the value checks, a backward pass through the first returned
    factor is run to confirm that it does not raise.

    Parameters
    ----------
    A
        iterable input (presumably a batch of matrices supplied by a
        fixture); each item is decomposed separately with ``torch.svd`` and
        compared with the matching slice of the batched result
    """
    batch_u, batch_d, batch_v = geoopt.linalg.svd(A)
    batched_factors = (batch_u, batch_d, batch_v)
    with torch.no_grad():
        for idx, matrix in enumerate(A):
            # Factors are checked in the order returned: u, d, v.
            for batched, expected in zip(batched_factors, torch.svd(matrix)):
                np.testing.assert_allclose(batched.detach()[idx], expected.detach())
    batch_u.sum().backward()  # this should work
github geoopt / geoopt / geoopt / manifolds / sphere.py View on Github external
def _configure_manifold_intersection(self, intersection: torch.Tensor):
        """Register the ``projector`` buffer computed from ``intersection``.

        Parameters
        ----------
        intersection : torch.Tensor
            tensor handed to the QR factorisation; only its first returned
            factor is used
        """
        basis, _upper = geoopt.linalg.batch_linalg.qr(intersection)
        # projector = Q Q^T over the last two dimensions
        projector = basis @ basis.transpose(-1, -2)
        self.register_buffer("projector", projector)
github geoopt / geoopt / geoopt / manifolds / sphere.py View on Github external
def _configure_manifold_complement(self, complement: torch.Tensor):
        """Register the ``projector`` buffer computed from ``complement``.

        Parameters
        ----------
        complement : torch.Tensor
            tensor handed to the QR factorisation; only its first returned
            factor is used
        """
        basis, _upper = geoopt.linalg.batch_linalg.qr(complement)
        # projector = I - Q Q^T: negate the outer product, then add one to
        # every diagonal entry in place.
        projector = -basis @ basis.transpose(-1, -2)
        projector.diagonal(dim1=-2, dim2=-1).add_(1)
        self.register_buffer("projector", projector)
github geoopt / geoopt / geoopt / manifolds / stiefel.py View on Github external
def proju(self, x: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
        """Return ``u - x @ sym(x^T u)``.

        Parameters
        ----------
        x : torch.Tensor
            left factor; transposed over its last two dimensions
        u : torch.Tensor
            tensor the correction term is subtracted from

        Returns
        -------
        torch.Tensor
            ``u`` minus ``x`` times ``linalg.batch_linalg.sym`` of ``x^T u``
        """
        xtu = x.transpose(-1, -2) @ u
        sym_xtu = linalg.batch_linalg.sym(xtu)
        return u - x @ sym_xtu
github geoopt / geoopt / geoopt / manifolds / stiefel.py View on Github external
def expmap(self, x: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
        """Compute ``[x, u] @ expm(W) @ Z`` from two matrix exponentials.

        ``W`` is the block matrix ``[[x^T u, -u^T u], [I, x^T u]]`` and ``Z``
        stacks ``expm(-x^T u)`` on top of a zero block.

        Parameters
        ----------
        x : torch.Tensor
            first operand (presumably the base point)
        u : torch.Tensor
            second operand (presumably the tangent vector)

        Returns
        -------
        torch.Tensor
            the product ``cat(x, u) @ expm(W) @ Z``
        """
        xtu = x.transpose(-1, -2) @ u
        utu = u.transpose(-1, -2) @ u

        # Identity with the same shape/dtype/device as utu: start from
        # zeros and bump the diagonal in place.
        identity = torch.zeros_like(utu)
        identity.diagonal(dim1=-2, dim2=-1).add_(1)

        log_w = linalg.block_matrix(((xtu, -utu), (identity, xtu)))
        w = linalg.expm(log_w)
        z = torch.cat((linalg.expm(-xtu), torch.zeros_like(utu)), dim=-2)
        return torch.cat((x, u), dim=-1) @ w @ z
github geoopt / geoopt / geoopt / manifolds / stiefel.py View on Github external
def expmap(self, x: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
        """Evaluate the closed-form product ``cat(x, u) @ expm(W) @ Z``.

        Here ``W = [[x^T u, -u^T u], [I, x^T u]]`` is assembled with
        ``linalg.block_matrix`` and ``Z`` is ``expm(-x^T u)`` concatenated
        with a zero block along dimension ``-2``.

        Parameters
        ----------
        x : torch.Tensor
            first operand (presumably the base point)
        u : torch.Tensor
            second operand (presumably the tangent vector)

        Returns
        -------
        torch.Tensor
            result of the matrix product described above
        """
        x_t_u = x.transpose(-1, -2) @ u
        u_t_u = u.transpose(-1, -2) @ u

        # Build the identity block in place on a zero tensor shaped like u^T u.
        eye_block = torch.zeros_like(u_t_u)
        eye_block.diagonal(dim1=-2, dim2=-1).add_(1)

        blocks = ((x_t_u, -u_t_u), (eye_block, x_t_u))
        exp_w = linalg.expm(linalg.block_matrix(blocks))
        exp_neg = linalg.expm(-x_t_u)
        z = torch.cat((exp_neg, torch.zeros_like(u_t_u)), dim=-2)
        stacked = torch.cat((x, u), dim=-1)
        return stacked @ exp_w @ z
github geoopt / geoopt / geoopt / manifolds / stiefel.py View on Github external
----------
        size : shape
            the desired output shape
        dtype : torch.dtype
            desired dtype
        device : torch.device
            desired device

        Returns
        -------
        ManifoldTensor
            random point on Stiefel manifold
        """
        self._assert_check_shape(size2shape(*size), "x")
        tens = torch.randn(*size, device=device, dtype=dtype)
        return ManifoldTensor(linalg.qr(tens)[0], manifold=self)
github geoopt / geoopt / geoopt / manifolds / birkhoff_polytope.py View on Github external
def proju(self, x, u):
        # takes batch data
        # batch_size, n, _ = x.shape
        x_shape = x.shape
        x = x.reshape(-1, x_shape[-2], x_shape[-1])
        batch_size, n = x.shape[0:2]

        e = torch.ones(batch_size, n, 1)
        I = torch.unsqueeze(torch.eye(x.shape[-1]), 0).repeat(batch_size, 1, 1)

        mu = x * u

        A = linalg.block_matrix([[I, x], [torch.transpose(x, 1, 2), I]])

        B = A[:, :, 1:]
        b = torch.cat(
            [
                torch.sum(mu, dim=2, keepdim=True),
                torch.transpose(torch.sum(mu, dim=1, keepdim=True), 1, 2),
            ],
            dim=1,
        )

        zeta, _ = torch.solve(
            B.transpose(1, 2) @ (b - A[:, :, 0:1]), B.transpose(1, 2) @ B
        )
        alpha = torch.cat([torch.ones(batch_size, 1, 1), zeta[:, 0 : n - 1]], dim=1)
        beta = zeta[:, n - 1 : 2 * n - 1]
        rgrad = mu - (alpha @ e.transpose(1, 2) + e @ beta.transpose(1, 2)) * x