How to use the geoopt.manifolds module in geoopt

To help you get started, we’ve selected a few geoopt examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

GitHub: geoopt / geoopt / tests / test_adam.py (view on GitHub)
def test_adam_poincare():
    """RiemannianAdam should drive a Poincare-ball point onto a fixed target.

    A random vector is mapped into the ball with ``expmap0`` and the squared
    Poincare distance to ``target`` is minimized for 2000 optimizer steps.
    """
    torch.manual_seed(44)
    target = torch.tensor([0.5, 0.5])
    # Map a random vector at the origin into the ball (curvature c=1).
    point = geoopt.manifolds.poincare.math.expmap0(torch.randn(2) / 2, c=1.0)
    point = geoopt.ManifoldParameter(point, manifold=geoopt.PoincareBall())
    optimizer = geoopt.optim.RiemannianAdam([point], lr=1e-2)

    def closure():
        # Recompute loss and gradient; the optimizer calls this every step.
        optimizer.zero_grad()
        loss = geoopt.manifolds.poincare.math.dist(point, target) ** 2
        loss.backward()
        return loss.item()

    for _ in range(2000):
        optimizer.step(closure)
    np.testing.assert_allclose(point.data, target, atol=1e-5, rtol=1e-5)
GitHub: geoopt / geoopt / tests / test_manifold.py (view on GitHub)
geoopt.manifolds.SphereSubspaceIntersection: functools.partial(
        pymanopt.manifolds.SphereSubspaceIntersection,
        U=np.random.RandomState(42).randn(10, 3),
    ),
    geoopt.manifolds.SphereSubspaceComplementIntersection: functools.partial(
        pymanopt.manifolds.SphereSubspaceComplementIntersection,
        U=np.random.RandomState(42).randn(10, 3),
    ),
}

# shapes to verify unary element implementation
shapes = {
    geoopt.manifolds.PoincareBall: (3,),
    geoopt.manifolds.EuclideanStiefel: (10, 5),
    geoopt.manifolds.CanonicalStiefel: (10, 5),
    geoopt.manifolds.Euclidean: (1,),
    geoopt.manifolds.Sphere: (10,),
    geoopt.manifolds.SphereSubspaceIntersection: (10,),
    geoopt.manifolds.SphereSubspaceComplementIntersection: (10,),
}

# Immutable record bundling one unary test case, including the optional
# pymanopt counterpart (``manopt_manifold``) used as a reference.
UnaryCase = collections.namedtuple(
    "UnaryCase",
    ["shape", "x", "ex", "v", "ev", "manifold", "manopt_manifold"],
)


@pytest.fixture()
def unary_case(manifold):
    shape = shapes[type(manifold)]
    np.random.seed(42)
    torch.manual_seed(43)
    if type(manifold) in mannopt:
GitHub: geoopt / geoopt / tests / test_manifold_basic.py (view on GitHub)
def poincare_case():
    """Yield float64 Poincare-ball cases for ``PoincareBall`` and ``PoincareBallExact``.

    A random vector ``raw`` is squashed into the unit ball as
    ``tanh(|raw|) * raw / |raw|``. The same point and tangent data are reused
    for both manifold variants.
    """
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.PoincareBall]
    raw = torch.randn(*shape, dtype=torch.float64) / 3
    ev = torch.randn(*shape, dtype=torch.float64) / 3
    raw_norm = torch.norm(raw)
    # Norm of the result is tanh(|raw|) < 1, i.e. strictly inside the ball.
    point = torch.tanh(raw_norm) * raw / raw_norm
    ex = point.clone()
    v = ev.clone()
    for manifold_cls in (geoopt.PoincareBall, geoopt.PoincareBallExact):
        manifold = manifold_cls().to(dtype=torch.float64)
        point = geoopt.ManifoldTensor(point, manifold=manifold)
        yield UnaryCase(shape, point, ex, v, ev, manifold)
GitHub: geoopt / geoopt / tests / test_manifold.py (view on GitHub)
def test_transport(unary_case, t):
    """Check geoopt's vector transport against pymanopt's ``transp``.

    Skipped when no pymanopt counterpart exists, and for canonical Stiefel.
    """
    manifold = unary_case.manifold
    reference = unary_case.manopt_manifold
    if reference is None:
        pytest.skip("pymanopt does not have {}".format(manifold))
    if isinstance(manifold, geoopt.manifolds.CanonicalStiefel):
        # The reference pymanopt Stiefel manifold is the Euclidean-metric one.
        pytest.skip("pymanopt uses euclidean Stiefel")
    point, vec = unary_case.x, unary_case.v

    # Step along vec by t, then transport vec itself along that step.
    target = point.retr(vec, t=t)
    transported = point.transp(vec, u=vec, t=t)

    expected = reference.transp(point.numpy(), target.numpy(), vec.numpy())
    np.testing.assert_allclose(transported, expected)
GitHub: geoopt / geoopt / tests / test_manifold_basic.py (view on GitHub)
def euclidean_stiefel_case():
    """Yield Euclidean-metric Stiefel cases (default and ``Exact`` variants).

    The point is ``U @ V^T`` from the SVD of a random float64 matrix. The
    tangent vector is a random matrix minus ``X @ sym(X^T ev)``.
    """
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.EuclideanStiefel]
    ex = torch.randn(*shape, dtype=torch.float64)
    ev = torch.randn(*shape, dtype=torch.float64)
    left, _, right = torch.svd(ex)
    point = left @ right.t()
    # Subtract X times the symmetric part of X^T ev.
    cross = point.t() @ ev
    tangent = ev - point @ (cross + cross.t()) / 2

    for manifold_cls in (
        geoopt.manifolds.EuclideanStiefel,
        geoopt.manifolds.EuclideanStiefelExact,
    ):
        manifold = manifold_cls()
        point = geoopt.ManifoldTensor(point, manifold=manifold)
        yield UnaryCase(shape, point, ex, tangent, ev, manifold)
GitHub: geoopt / geoopt / tests / test_manifold.py (view on GitHub)
        functools.partial(geoopt.manifolds.Stiefel, canonical=False),
        functools.partial(geoopt.manifolds.Stiefel, canonical=True),
        geoopt.manifolds.PoincareBall,
        geoopt.manifolds.Euclidean,
        geoopt.manifolds.Sphere,
        functools.partial(
            geoopt.manifolds.SphereSubspaceIntersection,
            torch.from_numpy(np.random.RandomState(42).randn(10, 3)),
        ),
        functools.partial(
            geoopt.manifolds.SphereSubspaceComplementIntersection,
            torch.from_numpy(np.random.RandomState(42).randn(10, 3)),
        ),
    ],
)
def manifold(request, retraction_order):
    man = request.param()
GitHub: geoopt / geoopt / tests / test_manifold_basic.py (view on GitHub)
def euclidean_stiefel_case():
    """Build Euclidean Stiefel test cases, first default then ``Exact``.

    Uses the orthogonal factor ``U V^T`` of a random matrix as the point and
    a random matrix with ``X @ sym(X^T ev)`` removed as the tangent.
    """
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.EuclideanStiefel]
    ex = torch.randn(*shape, dtype=torch.float64)
    ev = torch.randn(*shape, dtype=torch.float64)
    u_mat, _, v_mat = torch.svd(ex)
    x = u_mat @ v_mat.t()
    xt_ev = x.t() @ ev
    v = ev - x @ (xt_ev + xt_ev.t()) / 2

    def make_case(manifold):
        # Re-wrap the (possibly already wrapped) point for this manifold.
        nonlocal x
        x = geoopt.ManifoldTensor(x, manifold=manifold)
        return UnaryCase(shape, x, ex, v, ev, manifold)

    yield make_case(geoopt.manifolds.EuclideanStiefel())
    yield make_case(geoopt.manifolds.EuclideanStiefelExact())
GitHub: geoopt / geoopt / tests / test_manifold_basic.py (view on GitHub)
@pytest.fixture(autouse=True, params=list(range(1, 6)))
def seed(request):
    """Autouse fixture: run each test once per torch seed 1 through 5."""
    seed_value = request.param
    torch.manual_seed(seed_value)
    yield


@pytest.fixture(autouse=True, params=[torch.float64], ids=str)
def use_floatX(request):
    """Autouse fixture: set torch's default dtype for a test, then restore it.

    Yields the dtype that was installed.
    """
    previous = torch.get_default_dtype()
    requested = request.param
    torch.set_default_dtype(requested)
    yield requested
    # Teardown: put back whatever default dtype was active before the test.
    torch.set_default_dtype(previous)


# Point shape used by the *_case builders for each manifold class.
# NOTE(review): the ProductManifold size (10 + 3 + 6 + 1) presumably sums its
# factor dimensions — confirm against the product-manifold case.
manifold_shapes = dict(
    [
        (geoopt.manifolds.PoincareBall, (3,)),
        (geoopt.manifolds.EuclideanStiefel, (10, 5)),
        (geoopt.manifolds.CanonicalStiefel, (10, 5)),
        (geoopt.manifolds.Euclidean, (10,)),
        (geoopt.manifolds.Sphere, (10,)),
        (geoopt.manifolds.SphereExact, (10,)),
        (geoopt.manifolds.ProductManifold, (10 + 3 + 6 + 1,)),
    ]
)


# Immutable record for one unary test case, as yielded by the *_case builders.
UnaryCase = collections.namedtuple(
    "UnaryCase", ["shape", "x", "ex", "v", "ev", "manifold"]
)


def canonical_stiefel_case():
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.CanonicalStiefel]
    ex = torch.randn(*shape)
GitHub: geoopt / geoopt / tests / test_manifold.py (view on GitHub)
],
)
def manifold(request, retraction_order):
    """Instantiate the parametrized manifold with the requested retraction order.

    The test is skipped when ``set_default_order`` rejects ``retraction_order``
    with ``ValueError``. The configured manifold is returned after ``.double()``.
    """
    man = request.param()
    # Guard only the call that signals "unsupported order". A ValueError from
    # .double() is a genuine failure and must not be turned into a skip.
    try:
        man = man.set_default_order(retraction_order)
    except ValueError:
        pytest.skip("not supported retraction order for {}".format(man))
    return man.double()


# Map from geoopt manifold class to the pymanopt constructor that serves as
# its reference implementation. Each subspace-intersection entry builds its
# own basis U from a fresh RandomState(42), with shape (10, 3).
mannopt = dict(
    [
        (geoopt.manifolds.EuclideanStiefel, pymanopt.manifolds.Stiefel),
        (geoopt.manifolds.CanonicalStiefel, pymanopt.manifolds.Stiefel),
        (geoopt.manifolds.Euclidean, pymanopt.manifolds.Euclidean),
        (geoopt.manifolds.Sphere, pymanopt.manifolds.Sphere),
        (
            geoopt.manifolds.SphereSubspaceIntersection,
            functools.partial(
                pymanopt.manifolds.SphereSubspaceIntersection,
                U=np.random.RandomState(42).randn(10, 3),
            ),
        ),
        (
            geoopt.manifolds.SphereSubspaceComplementIntersection,
            functools.partial(
                pymanopt.manifolds.SphereSubspaceComplementIntersection,
                U=np.random.RandomState(42).randn(10, 3),
            ),
        ),
    ]
)

# shapes to verify unary element implementation
shapes = {
    geoopt.manifolds.PoincareBall: (3,),
    geoopt.manifolds.EuclideanStiefel: (10, 5),
    geoopt.manifolds.CanonicalStiefel: (10, 5),
    geoopt.manifolds.Euclidean: (1,),
    geoopt.manifolds.Sphere: (10,),
GitHub: geoopt / geoopt / tests / test_manifold.py (view on GitHub)
geoopt.manifolds.Euclidean: pymanopt.manifolds.Euclidean,
    geoopt.manifolds.Sphere: pymanopt.manifolds.Sphere,
    geoopt.manifolds.SphereSubspaceIntersection: functools.partial(
        pymanopt.manifolds.SphereSubspaceIntersection,
        U=np.random.RandomState(42).randn(10, 3),
    ),
    geoopt.manifolds.SphereSubspaceComplementIntersection: functools.partial(
        pymanopt.manifolds.SphereSubspaceComplementIntersection,
        U=np.random.RandomState(42).randn(10, 3),
    ),
}

# Point shape per manifold class, used to build unary test cases
# (looked up via ``type(manifold)``).
shapes = {}
shapes[geoopt.manifolds.PoincareBall] = (3,)
shapes[geoopt.manifolds.EuclideanStiefel] = (10, 5)
shapes[geoopt.manifolds.CanonicalStiefel] = (10, 5)
shapes[geoopt.manifolds.Euclidean] = (1,)
shapes[geoopt.manifolds.Sphere] = (10,)
shapes[geoopt.manifolds.SphereSubspaceIntersection] = (10,)
shapes[geoopt.manifolds.SphereSubspaceComplementIntersection] = (10,)

# One unary test case plus its optional pymanopt reference manifold.
_UNARY_CASE_FIELDS = ("shape", "x", "ex", "v", "ev", "manifold", "manopt_manifold")
UnaryCase = collections.namedtuple("UnaryCase", _UNARY_CASE_FIELDS)


@pytest.fixture()
def unary_case(manifold):
    shape = shapes[type(manifold)]
    np.random.seed(42)