How to use the geoopt.ManifoldParameter class in geoopt

To help you get started, we’ve selected a few geoopt examples based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

GitHub: geoopt/geoopt — tests/test_tensor_api.py (view on GitHub)
def test_compare_manifolds():
    """Re-wrapping a ManifoldTensor with a different manifold raises ValueError.

    The tensor is created on ``Euclidean()``. Wrapping it as a
    ManifoldParameter on ``Euclidean(ndim=1)`` must be rejected with a
    "Manifolds do not match" error.
    """
    base_manifold = geoopt.Euclidean()
    other_manifold = geoopt.Euclidean(ndim=1)
    source = geoopt.ManifoldTensor(10, manifold=base_manifold)
    # ``match`` is searched (regex) against the string form of the exception.
    with pytest.raises(ValueError, match="Manifolds do not match"):
        geoopt.ManifoldParameter(source, manifold=other_manifold)
GitHub: geoopt/geoopt — tests/test_rsgd.py (view on GitHub)
def test_init_manifold():
    """Run one zero-gradient RiemannianSGD step with ``stabilize=1``.

    A Stiefel-constrained ManifoldParameter and a Euclidean one are optimized
    together. After a single step:

    - the Euclidean parameter must be unchanged;
    - the Stiefel parameter must be contiguous;
    - the Stiefel parameter must differ from its starting value;
    - the Stiefel parameter must match the Stiefel projection of that
      starting value.

    NOTE(review): the Stiefel parameter is already passed through ``proj_()``
    before the step. The "must differ" check therefore appears to rely on the
    step/stabilization perturbing it measurably (``np.allclose`` defaults are
    tight). Confirm this is the intended setup.
    """
    torch.manual_seed(42)
    stiefel = geoopt.manifolds.Stiefel()
    euclidean = geoopt.manifolds.Euclidean()
    stiefel_init = torch.randn(10, 10)
    euclidean_init = torch.randn(10, 10)

    # proj_() is presumably an in-place projection (trailing underscore);
    # in-place edits of a grad-requiring leaf need no_grad().
    with torch.no_grad():
        on_stiefel = geoopt.ManifoldParameter(
            stiefel_init, manifold=stiefel
        ).proj_()
    on_euclidean = geoopt.ManifoldParameter(euclidean_init, manifold=euclidean)

    both = [on_stiefel, on_euclidean]
    for param in both:
        param.grad = torch.zeros_like(param)
    stiefel_before, euclidean_before = [param.clone() for param in both]

    optimizer = geoopt.optim.RiemannianSGD(both, lr=1, stabilize=1)
    # Depending on the PyTorch version, zero_grad() resets grads either to
    # zeros or to None.
    optimizer.zero_grad()
    optimizer.step()

    assert not np.allclose(on_stiefel.data, stiefel_before.data)
    assert on_stiefel.is_contiguous()
    np.testing.assert_allclose(on_euclidean.data, euclidean_before.data)
    np.testing.assert_allclose(
        on_stiefel.data, stiefel.projx(stiefel_before.data), atol=1e-4
    )
GitHub: geoopt/geoopt — tests/test_rhmc.py (view on GitHub)
def __init__(self, mu, sigma):
            """Store a Normal distribution and a Stiefel-constrained parameter.

            NOTE(review): this snippet is a method lifted out of a class whose
            header is not shown (it calls zero-argument ``super().__init__()``).
            Presumably the class is a ``torch.nn.Module`` used as the model in
            the RHMC tests. Confirm against the full source.

            Args:
                mu: location of the Normal distribution. It is also the
                    template (via ``torch.randn_like``) for the initial value
                    of ``self.x``.
                sigma: scale of the Normal distribution.
            """
            super().__init__()
            # Distribution object kept on the instance. How it is consumed
            # (e.g. as a log-density target) is not visible in this snippet.
            self.d = torch.distributions.Normal(mu, sigma)
            # Parameter tagged with the Stiefel manifold, initialized from
            # standard-normal noise shaped like ``mu``.
            # NOTE(review): the initial value is not projected onto the
            # manifold here (no proj_()/projx call). Confirm this is intended.
            self.x = geoopt.ManifoldParameter(
                torch.randn_like(mu), manifold=geoopt.Stiefel()
            )
GitHub: geoopt/geoopt — tests/test_rsgd.py (view on GitHub)
def test_rsgd_stiefel(params):
    """Check that RiemannianSGD can drive a Stiefel parameter onto a target.

    Args:
        params: keyword arguments forwarded to ``RiemannianSGD``. Presumably
            this is supplied by a pytest fixture/parametrization. Confirm.

    NOTE(review): this snippet appears truncated. The loop below only checks
    the convergence condition, and no optimizer step (e.g.
    ``optim.step(closure)``) is visible. As shown, ``X`` never moves. Compare
    against the full source.
    """
    stiefel = geoopt.manifolds.Stiefel()
    torch.manual_seed(42)
    # Random 20x10 starting point pushed onto the manifold. proj_() is
    # presumably in-place, hence the no_grad() context.
    with torch.no_grad():
        X = geoopt.ManifoldParameter(torch.randn(20, 10), manifold=stiefel).proj_()
    # Target point: replace Xstar's contents with its Stiefel projection.
    Xstar = torch.randn(20, 10)
    Xstar.set_(stiefel.projx(Xstar))

    def closure():
        # ``optim`` is looked up when the closure runs; it is assigned below.
        optim.zero_grad()
        loss = (X - Xstar).pow(2).sum()
        # Penalize deviation of X^T X from the identity (orthonormal columns),
        # weighted by 100. This makes optimization hard if the manifold
        # constraint is violated.
        loss += (X.t() @ X - torch.eye(X.shape[1])).pow(2).sum() * 100
        loss.backward()
        return loss.item()

    optim = geoopt.optim.RiemannianSGD([X], **params)
    # The random start must not already coincide with the target.
    assert (X - Xstar).norm() > 1e-5
    for _ in range(10000):
        if (X - Xstar).norm() < 1e-5:
            break
GitHub: geoopt/geoopt — tests/test_utils.py (view on GitHub)
def test_pickle3():
    """Round-trip a ManifoldParameter on an intersected Sphere via torch.save/load.

    The reloaded object must still be a ManifoldParameter. It must also match
    the original in stride, storage offset, ``requires_grad``, values,
    manifold type and manifold projector.
    """
    import inspect

    t = torch.ones(10)
    span = torch.randn(10, 2)
    sub_sphere = geoopt.manifolds.Sphere(intersection=span)
    p = geoopt.ManifoldParameter(t, manifold=sub_sphere)
    # The pickle carries a geoopt manifold object, not only tensors. Since
    # PyTorch 2.6, torch.load defaults to weights_only=True, which refuses to
    # unpickle such objects. Opt out explicitly wherever the flag exists;
    # older versions lack the parameter and must not receive it. The file is
    # written by this test itself, so full unpickling is safe here.
    load_kwargs = (
        {"weights_only": False}
        if "weights_only" in inspect.signature(torch.load).parameters
        else {}
    )
    with tempfile.TemporaryDirectory() as path:
        file_path = os.path.join(path, "tens.t7")
        torch.save(p, file_path)
        # Load before the temporary directory is removed.
        p1 = torch.load(file_path, **load_kwargs)
    assert isinstance(p1, geoopt.ManifoldParameter)
    assert p.stride() == p1.stride()
    assert p.storage_offset() == p1.storage_offset()
    assert p.requires_grad == p1.requires_grad
    np.testing.assert_allclose(p.detach(), p1.detach())
    assert isinstance(p.manifold, type(p1.manifold))
    np.testing.assert_allclose(p.manifold.projector, p1.manifold.projector)
GitHub: geoopt/geoopt — tests/test_adam.py (view on GitHub)
def test_adam_stiefel(params):
    """Check that RiemannianAdam can drive a Stiefel parameter onto a target.

    Args:
        params: extra keyword arguments forwarded to ``RiemannianAdam``.
            Presumably this is supplied by a pytest fixture/parametrization
            (confirm). It must not contain ``stabilize``, which is passed
            explicitly; a duplicate keyword would raise TypeError.

    NOTE(review): this snippet appears truncated. The loop below only checks
    the convergence condition, and no optimizer step (e.g.
    ``optim.step(closure)``) is visible. As shown, ``X`` never moves. Compare
    against the full source.
    """
    stiefel = geoopt.manifolds.Stiefel()
    torch.manual_seed(42)
    # Random 20x10 starting point pushed onto the manifold. proj_() is
    # presumably in-place, hence the no_grad() context.
    with torch.no_grad():
        X = geoopt.ManifoldParameter(torch.randn(20, 10), manifold=stiefel).proj_()
    # Target point: replace Xstar's contents with its Stiefel projection.
    Xstar = torch.randn(20, 10)
    Xstar.set_(stiefel.projx(Xstar))

    def closure():
        # ``optim`` is looked up when the closure runs; it is assigned below.
        optim.zero_grad()
        loss = (X - Xstar).pow(2).sum()
        # Penalize deviation of X^T X from the identity (orthonormal columns),
        # weighted by 100. This makes optimization hard if the manifold
        # constraint is violated.
        loss += (X.t() @ X - torch.eye(X.shape[1])).pow(2).sum() * 100
        loss.backward()
        return loss.item()

    # stabilize=4500: presumably re-projects parameters onto their manifold
    # every 4500 steps (see geoopt optimizer docs). Confirm.
    optim = geoopt.optim.RiemannianAdam([X], stabilize=4500, **params)
    # The random start must not already coincide with the target.
    assert (X - Xstar).norm() > 1e-5
    for _ in range(10000):
        if (X - Xstar).norm() < 1e-5:
            break