How to use geoopt - 10 common examples

To help you get started, we’ve selected a few geoopt examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github geoopt / geoopt / tests / test_product_manifold.py View on Github external
def test_init():
    """Check the element count and name of a product of two spheres."""
    # One component with 10 elements and one with shape (3, 2).
    components = ((Sphere(), 10), (Sphere(), (3, 2)))
    pman = ProductManifold(*components)

    expected_elements = 10 + 3 * 2
    assert pman.n_elements == expected_elements
    assert pman.name == "(Sphere)x(Sphere)"
github geoopt / geoopt / tests / test_manifold_basic.py View on Github external
def sphere_subspace_case():
    """Yield UnaryCase fixtures for spheres intersected with a random subspace.

    Two cases are produced from the same seeded data: one for
    ``geoopt.Sphere(intersection=subspace)`` and one for
    ``geoopt.SphereExact(intersection=subspace)``.
    """
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.Sphere]
    subspace = torch.rand(shape[-1], 2, dtype=torch.float64)

    # Q comes from the QR factorisation of ``subspace``; Q @ Q^T is used
    # below as the projector onto that subspace.
    basis, _ = geoopt.linalg.batch_linalg.qr(subspace)
    projector = basis @ basis.t()

    ex = torch.randn(*shape, dtype=torch.float64)
    ev = torch.randn(*shape, dtype=torch.float64)

    # Project ``ex`` with the projector, then rescale it to unit norm.
    ex_projected = ex @ projector.t()
    x = ex_projected / torch.norm(ex_projected)
    # Subtract the component of ``ev`` along ``x``, then project the result.
    v = (ev - (x @ ev) * x) @ projector.t()

    manifold = geoopt.Sphere(intersection=subspace)
    x = geoopt.ManifoldTensor(x, manifold=manifold)
    yield UnaryCase(shape, x, ex, v, ev, manifold)

    # Second case: same data, re-wrapped for the "exact" variant.
    manifold = geoopt.SphereExact(intersection=subspace)
    x = geoopt.ManifoldTensor(x, manifold=manifold)
    yield UnaryCase(shape, x, ex, v, ev, manifold)
github geoopt / geoopt / tests / test_manifold_basic.py View on Github external
def canonical_stiefel_case():
    """Yield a single UnaryCase fixture for ``CanonicalStiefel``."""
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.CanonicalStiefel]
    ex = torch.randn(*shape)
    ev = torch.randn(*shape)

    # Build the point from the outer SVD factors of ``ex``.
    left, _, right = torch.svd(ex)
    point = left @ right.t()
    # The tangent is computed from the plain tensor, before wrapping.
    tangent = ev - point @ ev.t() @ point

    manifold = geoopt.manifolds.CanonicalStiefel()
    x = geoopt.ManifoldTensor(point, manifold=manifold)
    yield UnaryCase(shape, x, ex, tangent, ev, manifold)
github geoopt / geoopt / tests / test_tensor_api.py View on Github external
def test_compare_manifolds():
    """Re-wrapping a tensor as a parameter on another manifold must fail.

    The tensor is attached to ``Euclidean()``, while the parameter is
    requested with ``Euclidean(ndim=1)``; a ``ValueError`` whose message
    matches "Manifolds do not match" is expected.
    """
    source_manifold = geoopt.Euclidean()
    other_manifold = geoopt.Euclidean(ndim=1)
    tensor = geoopt.ManifoldTensor(10, manifold=source_manifold)

    with pytest.raises(ValueError, match="Manifolds do not match"):
        geoopt.ManifoldParameter(tensor, manifold=other_manifold)
github geoopt / geoopt / tests / test_manifold_basic.py View on Github external
def euclidean_case():
    """Yield a single UnaryCase fixture for ``Euclidean(ndim=1)``.

    The point and tangent are plain clones of the seeded random tensors
    ``ex`` and ``ev``.
    """
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.Euclidean]
    ex = torch.randn(*shape, dtype=torch.float64)
    ev = torch.randn(*shape, dtype=torch.float64)

    manifold = geoopt.Euclidean(ndim=1)
    point = geoopt.ManifoldTensor(ex.clone(), manifold=manifold)
    tangent = ev.clone()
    yield UnaryCase(shape, point, ex, tangent, ev, manifold)
github geoopt / geoopt / tests / test_manifold_basic.py View on Github external
def poincare_case():
    """Yield UnaryCase fixtures for ``PoincareBall`` and ``PoincareBallExact``.

    Both manifolds are converted to ``float64`` and share the same seeded
    point and tangent data.
    """
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.PoincareBall]
    raw_point = torch.randn(*shape, dtype=torch.float64) / 3
    ev = torch.randn(*shape, dtype=torch.float64) / 3

    # Rescale so the point's norm becomes tanh(original norm), i.e. below 1.
    raw_norm = torch.norm(raw_point)
    x = torch.tanh(raw_norm) * raw_point / raw_norm
    # The rescaled point doubles as the "ex" entry of the case.
    ex = x.clone()
    v = ev.clone()

    manifold = geoopt.PoincareBall().to(dtype=torch.float64)
    x = geoopt.ManifoldTensor(x, manifold=manifold)
    yield UnaryCase(shape, x, ex, v, ev, manifold)

    # Second case: same data, re-wrapped for the "exact" variant.
    manifold = geoopt.PoincareBallExact().to(dtype=torch.float64)
    x = geoopt.ManifoldTensor(x, manifold=manifold)
    yield UnaryCase(shape, x, ex, v, ev, manifold)
github geoopt / geoopt / tests / test_manifold_basic.py View on Github external
def sphere_case():
    """Yield UnaryCase fixtures for ``Sphere`` and ``SphereExact``.

    Both cases share the same seeded point and tangent data.
    """
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.Sphere]
    ex = torch.randn(*shape, dtype=torch.float64)
    ev = torch.randn(*shape, dtype=torch.float64)

    # Unit-norm point, and ``ev`` with its component along that point removed.
    x = ex / torch.norm(ex)
    v = ev - (x @ ev) * x

    manifold = geoopt.Sphere()
    x = geoopt.ManifoldTensor(x, manifold=manifold)
    yield UnaryCase(shape, x, ex, v, ev, manifold)

    # Second case: same data, re-wrapped for the "exact" variant.
    manifold = geoopt.SphereExact()
    x = geoopt.ManifoldTensor(x, manifold=manifold)
    yield UnaryCase(shape, x, ex, v, ev, manifold)
github geoopt / geoopt / tests / test_tensor_api.py View on Github external
def test_tensor_is_attached():
    """A point drawn via ``manifold.random`` must report as attached to it."""
    manifold = geoopt.Euclidean()
    # An empty shape tuple is passed as the only argument to ``random``.
    sample = manifold.random(())
    assert manifold.is_attached(sample)
github geoopt / geoopt / tests / test_utils.py View on Github external
def test_ismanifold():
    """Exercise ``geoopt.ismanifold`` on plain, wrapped and invalid inputs."""
    base = geoopt.Euclidean()
    assert geoopt.ismanifold(base, geoopt.Euclidean)

    # Two nested ``Scaled`` wrappers must still be recognised as Euclidean.
    wrapped = geoopt.Scaled(base)
    wrapped = geoopt.Scaled(wrapped)
    assert geoopt.ismanifold(wrapped, geoopt.Euclidean)

    # Neither ``int`` nor ``1`` is accepted as the class argument.
    for invalid_class in (int, 1):
        with pytest.raises(TypeError):
            geoopt.ismanifold(wrapped, invalid_class)

    # A non-manifold first argument gives a falsy result instead of raising.
    assert not geoopt.ismanifold(1, geoopt.Euclidean)
github geoopt / geoopt / tests / test_manifold_basic.py View on Github external
manifold_shapes[geoopt.ProductManifold],
        product_manifold.pack_point(*x),
        product_manifold.pack_point(*ex),
        product_manifold.pack_point(*v),
        product_manifold.pack_point(*ev),
        product_manifold,
    )
    # + 1 case without stiefel
    torch.manual_seed(42)
    ex = [torch.randn(10), torch.randn(3) / 10, torch.randn(())]
    ev = [torch.randn(10), torch.randn(3) / 10, torch.randn(())]
    manifolds = [
        geoopt.Sphere(),
        geoopt.PoincareBall(),
        # geoopt.Stiefel(),
        geoopt.Euclidean(),
    ]
    x = [manifolds[i].projx(ex[i]) for i in range(len(manifolds))]
    v = [manifolds[i].proju(x[i], ev[i]) for i in range(len(manifolds))]

    product_manifold = geoopt.ProductManifold(
        *((manifolds[i], ex[i].shape) for i in range(len(ex)))
    )

    yield UnaryCase(
        manifold_shapes[geoopt.ProductManifold],
        product_manifold.pack_point(*x),
        product_manifold.pack_point(*ex),
        product_manifold.pack_point(*v),
        product_manifold.pack_point(*ev),
        product_manifold,
    )