How to use the geoopt.Euclidean manifold class in geoopt

To help you get started, we’ve selected a few geoopt examples based on how it is commonly used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

GitHub: geoopt/geoopt — tests/test_tensor_api.py (view on GitHub)
def test_tensor_is_attached():
    """Check that a point drawn via `random` from a Euclidean manifold is attached to it."""
    manifold = geoopt.Euclidean()
    point = manifold.random(())
    assert manifold.is_attached(point)
GitHub: geoopt/geoopt — tests/test_utils.py (view on GitHub)
def test_ismanifold():
    """Check `geoopt.ismanifold` on a plain and a doubly-Scaled Euclidean manifold.

    Wrapping in `geoopt.Scaled` must not hide the base `Euclidean` type.
    Passing `int` or the value ``1`` as the class argument must raise
    ``TypeError``. A non-manifold first argument must give a falsy result.
    """
    base = geoopt.Euclidean()
    assert geoopt.ismanifold(base, geoopt.Euclidean)

    wrapped = geoopt.Scaled(geoopt.Scaled(base))
    assert geoopt.ismanifold(wrapped, geoopt.Euclidean)

    # Invalid class arguments: a non-manifold class, then a non-class value.
    for bad_cls in (int, 1):
        with pytest.raises(TypeError):
            geoopt.ismanifold(wrapped, bad_cls)

    assert not geoopt.ismanifold(1, geoopt.Euclidean)
GitHub: geoopt/geoopt — tests/test_manifold_basic.py (view on GitHub)
# NOTE(review): fragment from tests/test_manifold_basic.py. The enclosing
# generator's `def` line and the opening of the first case (presumably
# `yield UnaryCase(`, mirroring the second case below) are not shown, so this
# text is not valid Python on its own.
manifold_shapes[geoopt.ProductManifold],
        product_manifold.pack_point(*x),
        product_manifold.pack_point(*ex),
        product_manifold.pack_point(*v),
        product_manifold.pack_point(*ev),
        product_manifold,
    )
    # + 1 case without stiefel
    # Second product case: Sphere x PoincareBall x Euclidean, with Stiefel
    # commented out. The seed is fixed so the raw tensors are reproducible.
    torch.manual_seed(42)
    # Raw points (ex) and raw vectors (ev), one per factor; shapes (10,), (3,), ().
    ex = [torch.randn(10), torch.randn(3) / 10, torch.randn(())]
    ev = [torch.randn(10), torch.randn(3) / 10, torch.randn(())]
    manifolds = [
        geoopt.Sphere(),
        geoopt.PoincareBall(),
        # geoopt.Stiefel(),
        geoopt.Euclidean(),
    ]
    # Project each raw point with its factor's projx, then each raw vector
    # with proju at the corresponding projected point.
    x = [manifolds[i].projx(ex[i]) for i in range(len(manifolds))]
    v = [manifolds[i].proju(x[i], ev[i]) for i in range(len(manifolds))]

    # Each factor's shape is taken from its raw point tensor.
    product_manifold = geoopt.ProductManifold(
        *((manifolds[i], ex[i].shape) for i in range(len(ex)))
    )

    # UnaryCase arguments, in order: shape spec, packed projected point,
    # packed raw point, packed projected vector, packed raw vector, manifold.
    yield UnaryCase(
        manifold_shapes[geoopt.ProductManifold],
        product_manifold.pack_point(*x),
        product_manifold.pack_point(*ex),
        product_manifold.pack_point(*v),
        product_manifold.pack_point(*ev),
        product_manifold,
    )
GitHub: geoopt/geoopt — tests/test_scaling.py (view on GitHub)
def test_tensor_is_attached():
    """Check that `random` on a doubly-Scaled Euclidean manifold yields an attached point."""
    scaled = geoopt.Scaled(geoopt.Scaled(geoopt.Euclidean()))
    sample = scaled.random(())
    assert scaled.is_attached(sample)
GitHub: geoopt/geoopt — tests/test_utils.py (view on GitHub)
def test_ismanifold():
    """Exercise `geoopt.ismanifold` through two layers of `geoopt.Scaled`.

    The scaled wrapper must still be recognised as `Euclidean`. Invalid class
    arguments (`int`, ``1``) raise ``TypeError``. A non-manifold object is
    reported as not being a manifold.
    """
    euclidean = geoopt.Euclidean()
    assert geoopt.ismanifold(euclidean, geoopt.Euclidean)

    scaled_twice = euclidean
    for _ in range(2):
        scaled_twice = geoopt.Scaled(scaled_twice)
    assert geoopt.ismanifold(scaled_twice, geoopt.Euclidean)

    with pytest.raises(TypeError):
        geoopt.ismanifold(scaled_twice, int)
    with pytest.raises(TypeError):
        geoopt.ismanifold(scaled_twice, 1)

    assert not geoopt.ismanifold(1, geoopt.Euclidean)
GitHub: geoopt/geoopt — tests/test_product_manifold.py (view on GitHub)
def test_inner_product():
    """Check the output shape of `ProductManifold.inner`, with and without ``keepdim``.

    The product has three factors: a Sphere over 10 entries, a Sphere over a
    (3, 2) block, and a scalar Euclidean factor. It is evaluated on packed
    points with leading size 5, using a random vector projected with `proju`.
    """
    pman = ProductManifold((Sphere(), 10), (Sphere(), (3, 2)), (Euclidean(), ()))
    parts = (
        Sphere().random_uniform(5, 10),
        Sphere().random_uniform(5, 3, 2),
        Euclidean().random_normal(5),
    )
    tensor = pman.pack_point(*parts)
    tangent = pman.proju(tensor, torch.randn_like(tensor))

    assert pman.inner(tensor, tangent).shape == (5,)
    assert pman.inner(tensor, tangent, keepdim=True).shape == (5, 1)
GitHub: geoopt/geoopt — tests/test_random.py (view on GitHub)
def test_product():
    """Check that `random` on a four-factor ProductManifold returns a point on the manifold."""
    factors = (
        (geoopt.Sphere(), 10),
        (geoopt.PoincareBall(), 3),
        (geoopt.Stiefel(), (20, 2)),
        (geoopt.Euclidean(), 43),
    )
    manifold = geoopt.ProductManifold(*factors)
    sample = manifold.random(20, manifold.n_elements)
    manifold.assert_check_point_on_manifold(sample)
GitHub: geoopt/geoopt — tests/test_product_manifold.py (view on GitHub)
def test_from_point_checks_shapes():
    """Check `ProductManifold.from_point` size bookkeeping and batch-shape validation.

    Without ``batch_dims`` the parts are accepted, and ``n_elements`` equals the
    total number of entries across all parts. With ``batch_dims=1`` the
    differing leading sizes (5, 3, 5) must be rejected with ``ValueError``.
    """
    sphere_a = Sphere().random_uniform(5, 10)
    sphere_b = Sphere().random_uniform(3, 3, 2)
    flat = Euclidean().random_normal(5)
    point = [sphere_a, sphere_b, flat]

    pman = ProductManifold.from_point(*point)
    assert pman.n_elements == 5 * 10 + 3 * 3 * 2 + 5 * 1

    with pytest.raises(ValueError, match="Not all parts have same batch shape"):
        ProductManifold.from_point(*point, batch_dims=1)
GitHub: geoopt/geoopt — tests/test_random.py (view on GitHub)
def test_fails_Euclidean():
    """Check that `Euclidean(ndim=1).random_normal(())` raises ``ValueError``.

    Only the sampling call is guarded by ``pytest.raises``. The behaviour under
    test is the sampling failure, so a ``ValueError`` raised while constructing
    the manifold should fail the test rather than silently satisfy it.
    """
    manifold = geoopt.Euclidean(ndim=1)
    with pytest.raises(ValueError):
        manifold.random_normal(())