How to use the geoopt.Scaled class in geoopt

To help you get started, we’ve selected a few geoopt examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github geoopt / geoopt / tests / test_scaling.py View on Github external
def test_scale_poincare():
    """Compare origin distances on a Poincare ball and its ``Scaled`` wrapper.

    Asserts (with ``atol=1e-5``) that ``dist0(expmap0(v))`` gives the same
    value on the plain ball and on the ball wrapped with scale 2, and that
    the scaled value matches ``norm`` of ``v`` taken at the zero point.
    """
    base = geoopt.PoincareBallExact()
    scaled = geoopt.Scaled(base, 2)
    tangent = torch.arange(10).float() / 10

    base_dist = base.dist0(base.expmap0(tangent)).item()
    scaled_dist = scaled.dist0(scaled.expmap0(tangent)).item()
    np.testing.assert_allclose(base_dist, scaled_dist, atol=1e-5)

    origin = torch.zeros_like(tangent)
    scaled_dist_again = scaled.dist0(scaled.expmap0(tangent)).item()
    scaled_norm = scaled.norm(origin, tangent)
    np.testing.assert_allclose(scaled_dist_again, scaled_norm, atol=1e-5)
github geoopt / geoopt / tests / test_scaling.py View on Github external
def test_rescaling_methods_accessible():
    """Smoke test: ``geodesic`` is callable on a doubly wrapped ``Scaled`` ball.

    Nothing is asserted about the result; the call only has to succeed.
    """
    doubly_scaled = geoopt.Scaled(
        geoopt.Scaled(geoopt.PoincareBallExact(), 2),
        0.5,
    )
    start = torch.arange(10).float() / 10
    end = -torch.arange(10).float() / 10
    doubly_scaled.geodesic(0.5, start, end)
github geoopt / geoopt / tests / test_scaling.py View on Github external
def test_scaling_compensates():
    """Wrapping with scale 2 and then 0.5 should leave ``expmap0`` unchanged."""
    base = geoopt.PoincareBallExact()
    # Stack the two wrappers: the outer 0.5 is applied on top of the inner 2.
    round_trip = geoopt.Scaled(geoopt.Scaled(base, 2), 0.5)
    tangent = torch.arange(10).float() / 10

    plain_point = base.expmap0(tangent)
    round_trip_point = round_trip.expmap0(tangent)
    np.testing.assert_allclose(plain_point, round_trip_point)
github geoopt / geoopt / tests / test_scaling.py View on Github external
def test_tensor_is_attached():
    """A point sampled from a twice-wrapped manifold reports as attached to it."""
    manifold = geoopt.Euclidean()
    # Apply the ``Scaled`` wrapper (default scale) two times in a row.
    for _ in range(2):
        manifold = geoopt.Scaled(manifold)
    sample = manifold.random(())
    assert manifold.is_attached(sample)
github geoopt / geoopt / tests / test_manifold_basic.py View on Github external
def unary_case(unary_case_base, scaled):
    """Return the base case, optionally with its manifold wrapped in ``Scaled``.

    Parameters
    ----------
    unary_case_base
        Case object; when ``scaled`` is truthy it must expose ``manifold``
        and ``_replace`` (presumably a namedtuple -- confirm at definition).
    scaled
        Truthy to wrap the manifold in ``geoopt.Scaled`` with scale 2.

    Returns
    -------
    The untouched ``unary_case_base`` when ``scaled`` is falsy, otherwise a
    copy whose ``manifold`` field is the scaled wrapper.
    """
    if not scaled:
        return unary_case_base
    scaled_manifold = geoopt.Scaled(unary_case_base.manifold, 2)
    return unary_case_base._replace(manifold=scaled_manifold)
github geoopt / geoopt / tests / test_scaling.py View on Github external
def test_scaling_compensates():
    """Check that scaling by 2 and then by 0.5 gives the same ``expmap0`` as no scaling."""
    tangent = torch.arange(10).float() / 10
    plain = geoopt.PoincareBallExact()
    doubled = geoopt.Scaled(plain, 2)
    restored = geoopt.Scaled(doubled, 0.5)
    np.testing.assert_allclose(
        plain.expmap0(tangent),
        restored.expmap0(tangent),
    )
github geoopt / geoopt / tests / test_utils.py View on Github external
def test_ismanifold():
    """Exercise ``geoopt.ismanifold`` on plain, wrapped and invalid arguments.

    Covers: a bare ``Euclidean`` instance, the same instance under two
    ``Scaled`` wrappers, invalid ``cls`` arguments (expected ``TypeError``),
    and a non-manifold instance (expected falsy result).
    """
    euclidean = geoopt.Euclidean()
    assert geoopt.ismanifold(euclidean, geoopt.Euclidean)

    wrapped = geoopt.Scaled(geoopt.Scaled(euclidean))
    assert geoopt.ismanifold(wrapped, geoopt.Euclidean)

    # Neither a non-manifold type nor a non-type is accepted as ``cls``.
    for invalid_cls in (int, 1):
        with pytest.raises(TypeError):
            geoopt.ismanifold(wrapped, invalid_cls)

    assert not geoopt.ismanifold(1, geoopt.Euclidean)
github geoopt / geoopt / tests / test_scaling.py View on Github external
def test_rescaling_methods_accessible():
    """Call ``geodesic`` through two stacked ``Scaled`` wrappers; it must not raise."""
    point_a = torch.arange(10).float() / 10
    point_b = -torch.arange(10).float() / 10

    inner = geoopt.Scaled(geoopt.PoincareBallExact(), 2)
    outer = geoopt.Scaled(inner, 0.5)
    outer.geodesic(0.5, point_a, point_b)
github geoopt / geoopt / tests / test_utils.py View on Github external
def test_ismanifold():
    """Verify ``geoopt.ismanifold`` for direct, ``Scaled``-wrapped and bad inputs."""
    target_cls = geoopt.Euclidean

    manifold = geoopt.Euclidean()
    assert geoopt.ismanifold(manifold, target_cls)

    # Two layers of ``Scaled`` must still be recognised as the base class.
    for _ in range(2):
        manifold = geoopt.Scaled(manifold)
    assert geoopt.ismanifold(manifold, target_cls)

    # A type that is not a manifold class is rejected.
    with pytest.raises(TypeError):
        geoopt.ismanifold(manifold, int)

    # A non-type ``cls`` is rejected as well.
    with pytest.raises(TypeError):
        geoopt.ismanifold(manifold, 1)

    # A non-manifold instance yields a falsy result rather than an error.
    assert not geoopt.ismanifold(1, target_cls)
github geoopt / geoopt / geoopt / utils.py View on Github external
check if a given manifold is compatible with cls API
    cls : type
        manifold type

    Returns
    -------
    bool
        comparison result
    """
    if not issubclass(cls, geoopt.manifolds.Manifold):
        raise TypeError("`cls` should be a subclass of geoopt.manifolds.Manifold")
    if not isinstance(instance, geoopt.manifolds.Manifold):
        return False
    else:
        # this is the case to care about, Scaled class is a proxy, but fails instance checks
        while isinstance(instance, geoopt.Scaled):
            instance = instance.base
        return isinstance(instance, cls)