How to use the geoopt.PoincareBall class in geoopt

To help you get started, we've selected a few geoopt examples based on popular ways PoincareBall is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

GitHub: geoopt/geoopt — tests/test_adam.py (view on GitHub)
def test_adam_poincare():
    """RiemannianAdam should move a point on the Poincare ball onto a fixed target.

    A random start point is created with ``expmap0`` (c=1.0) and wrapped as a
    ``ManifoldParameter`` on ``geoopt.PoincareBall()``. The squared hyperbolic
    distance to ``target`` is then minimised for 2000 steps with lr=1e-2. The
    optimised point must match the target within atol=rtol=1e-5.
    """
    torch.manual_seed(44)
    target = torch.tensor([0.5, 0.5])
    # Push a random tangent vector at the origin onto the ball so the
    # initial parameter is a valid manifold point.
    initial = geoopt.manifolds.poincare.math.expmap0(torch.randn(2) / 2, c=1.0)
    point = geoopt.ManifoldParameter(initial, manifold=geoopt.PoincareBall())
    optimizer = geoopt.optim.RiemannianAdam([point], lr=1e-2)

    def closure():
        # Recompute the objective and its gradient for one optimizer step.
        optimizer.zero_grad()
        objective = geoopt.manifolds.poincare.math.dist(point, target) ** 2
        objective.backward()
        return objective.item()

    for _ in range(2000):
        optimizer.step(closure)
    np.testing.assert_allclose(point.data, target, atol=1e-5, rtol=1e-5)
GitHub: geoopt/geoopt — tests/test_product_manifold.py (view on GitHub)
def test_dtype_checked_properly():
    """ProductManifold must reject factor manifolds whose dtypes disagree."""
    default_ball = PoincareBall()
    double_ball = PoincareBall().double()
    with pytest.raises(ValueError, match="Not all manifold share the same dtype"):
        ProductManifold((default_ball, (10,)), (double_ball, (12,)))
GitHub: geoopt/geoopt — tests/test_random.py (view on GitHub)
def test_product():
    """Random samples from a mixed ProductManifold must lie on that manifold.

    The product combines a Sphere, a PoincareBall, a Stiefel and a Euclidean
    factor, each with its own shape.
    """
    factors = (
        (geoopt.Sphere(), 10),
        (geoopt.PoincareBall(), 3),
        (geoopt.Stiefel(), (20, 2)),
        (geoopt.Euclidean(), 43),
    )
    product = geoopt.ProductManifold(*factors)
    # Size arguments are (20, n_elements) — presumably a batch of 20 packed points.
    batch = product.random(20, product.n_elements)
    product.assert_check_point_on_manifold(batch)
GitHub: geoopt/geoopt — tests/test_manifold_basic.py (view on GitHub)
def poincare_case():
    """Yield ``UnaryCase`` fixtures for ``PoincareBall`` and ``PoincareBallExact``.

    Both cases share one float64 point, one expected point and one pair of
    tangent vectors. The point is a random vector rescaled so that its norm
    becomes ``tanh`` of its original norm, which keeps it below 1. Only the
    manifold instance differs between the two yielded cases.
    """
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.PoincareBall]
    raw_point = torch.randn(*shape, dtype=torch.float64) / 3
    raw_vector = torch.randn(*shape, dtype=torch.float64) / 3
    # Keep the direction of raw_point but shrink its length to tanh(length).
    length = torch.norm(raw_point)
    point = torch.tanh(length) * raw_point / length
    expected_point = point.clone()
    vector = raw_vector.clone()

    def make_case(base_manifold):
        # Re-wrap the current point with this case's manifold. The second case
        # wraps the ManifoldTensor produced for the first one.
        nonlocal point
        manifold = base_manifold.to(dtype=torch.float64)
        point = geoopt.ManifoldTensor(point, manifold=manifold)
        return UnaryCase(shape, point, expected_point, vector, raw_vector, manifold)

    yield make_case(geoopt.PoincareBall())
    yield make_case(geoopt.PoincareBallExact())
GitHub: geoopt/geoopt — tests/test_product_manifold.py (view on GitHub)
def test_dtype_checked_properly():
    """Mixing a default-dtype and a ``.double()`` PoincareBall in one
    ProductManifold must raise a ValueError about mismatched dtypes."""
    factors = (
        (PoincareBall(), (10,)),
        (PoincareBall().double(), (12,)),
    )
    with pytest.raises(ValueError) as excinfo:
        ProductManifold(*factors)
    assert excinfo.match("Not all manifold share the same dtype")
GitHub: geoopt/geoopt — tests/test_random.py (view on GitHub)
def test_random_Poincare():
    """``random_normal`` on a PoincareBall must return valid points tied to that ball."""
    ball = geoopt.PoincareBall()
    sample = ball.random_normal(3, 10, 10)
    ball.assert_check_point_on_manifold(sample)
    # The returned tensor must reference the very same manifold instance.
    assert sample.manifold is ball
GitHub: geoopt/geoopt — tests/test_manifold_basic.py (view on GitHub)
yield UnaryCase(
        manifold_shapes[geoopt.ProductManifold],
        product_manifold.pack_point(*x),
        product_manifold.pack_point(*ex),
        product_manifold.pack_point(*v),
        product_manifold.pack_point(*ev),
        product_manifold,
    )
    # + 1 case without stiefel
    torch.manual_seed(42)
    ex = [torch.randn(10), torch.randn(3) / 10, torch.randn(())]
    ev = [torch.randn(10), torch.randn(3) / 10, torch.randn(())]
    manifolds = [
        geoopt.Sphere(),
        geoopt.PoincareBall(),
        # geoopt.Stiefel(),
        geoopt.Euclidean(),
    ]
    x = [manifolds[i].projx(ex[i]) for i in range(len(manifolds))]
    v = [manifolds[i].proju(x[i], ev[i]) for i in range(len(manifolds))]

    product_manifold = geoopt.ProductManifold(
        *((manifolds[i], ex[i].shape) for i in range(len(ex)))
    )

    yield UnaryCase(
        manifold_shapes[geoopt.ProductManifold],
        product_manifold.pack_point(*x),
        product_manifold.pack_point(*ex),
        product_manifold.pack_point(*v),
        product_manifold.pack_point(*ev),
GitHub: geoopt/geoopt — tests/test_random.py (view on GitHub)
def test_fails_Poincare():
    """``random_normal`` called with an empty size tuple ``()`` must raise ValueError.

    The manifold is built outside ``pytest.raises`` so that only the call under
    test is guarded. Otherwise a ValueError from the PoincareBall constructor
    would make this test pass by accident.
    """
    manifold = geoopt.PoincareBall()
    with pytest.raises(ValueError):
        manifold.random_normal(())