How to use the numpyro.distributions.Uniform distribution class in numpyro

To help you get started, we’ve selected a few numpyro examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github pyro-ppl / numpyro / test / test_svi.py View on Github external
def test_elbo_dynamic_support():
    """Check the ELBO loss for a transformed prior paired with a Uniform guide.

    The prior is a default ``Normal()`` pushed through an affine, a sigmoid
    and a second affine transform; the guide is ``Uniform(0, 3)``. The
    guide's ``'x'`` site is substituted with a fixed value, and the loss
    returned by ``svi.evaluate`` must be finite and match
    ``x_guide.log_prob(x) - x_prior.log_prob(x)`` at that value.
    """
    base = dist.Normal()
    transform_chain = [
        AffineTransform(0, 2),
        SigmoidTransform(),
        AffineTransform(0, 3),
    ]
    x_prior = dist.TransformedDistribution(base, transform_chain)
    x_guide = dist.Uniform(0, 3)

    def model():
        numpyro.sample('x', x_prior)

    def free_guide():
        numpyro.sample('x', x_guide)

    optimizer = optim.Adam(0.01)
    fixed_x = 2.
    # Wrap the guide so its 'x' site is supplied through ``param_map``.
    pinned_guide = substitute(free_guide, param_map={'x': fixed_x})

    svi = SVI(model, pinned_guide, optimizer, ELBO())
    state = svi.init(random.PRNGKey(0))
    actual_loss = svi.evaluate(state)
    assert jnp.isfinite(actual_loss)

    guide_log_prob = x_guide.log_prob(fixed_x)
    prior_log_prob = x_prior.log_prob(fixed_x)
    expected_loss = guide_log_prob - prior_log_prob
    assert_allclose(actual_loss, expected_loss)
github pyro-ppl / numpyro / test / test_mcmc.py View on Github external
def model(data):
    """Sample ``alpha`` and ``loc``, then observe ``data``.

    ``alpha`` is drawn from ``Uniform(0, 1)``. ``loc`` is drawn from
    ``Uniform(0, alpha)`` inside a ``reparam`` handler configured with
    ``TransformReparam`` for the ``'loc'`` site. ``data`` is observed at the
    ``'obs'`` site under ``Normal(loc, 0.1)``.

    :param data: observed values passed as ``obs`` to the ``'obs'`` site.
    """
    upper = numpyro.sample('alpha', dist.Uniform(0, 1))

    reparam_config = {'loc': TransformReparam()}
    with numpyro.handlers.reparam(config=reparam_config):
        loc_prior = dist.Uniform(0, upper)
        center = numpyro.sample('loc', loc_prior)

    likelihood = dist.Normal(center, 0.1)
    numpyro.sample('obs', likelihood, obs=data)
github pyro-ppl / numpyro / test / test_mcmc.py View on Github external
def model(data):
    """Observe ``data`` under a Poisson whose rate depends on position.

    Two rates, ``lambda1`` and ``lambda2``, are each drawn from
    ``Exponential(1 / mean(data))`` and ``tau`` from ``Uniform(0, 1)``.
    Entries whose index is below ``tau * len(data)`` use ``lambda1``; the
    remaining entries use ``lambda2``.

    :param data: observed values passed as ``obs`` to the ``'obs'`` site.
    """
    prior_rate = 1 / jnp.mean(data)
    lambda1 = numpyro.sample('lambda1', dist.Exponential(prior_rate))
    lambda2 = numpyro.sample('lambda2', dist.Exponential(prior_rate))
    tau = numpyro.sample('tau', dist.Uniform(0, 1))

    n_obs = len(data)
    # True for positions that take the first rate.
    uses_first_rate = jnp.arange(n_obs) < tau * n_obs
    lambda12 = jnp.where(uses_first_rate, lambda1, lambda2)

    numpyro.sample('obs', dist.Poisson(lambda12), obs=data)
github pyro-ppl / numpyro / test / test_mcmc.py View on Github external
def model(data):
    """Sample ``alpha`` and an affinely scaled ``loc``, then observe ``data``.

    ``alpha`` is drawn from ``Uniform(0, 1)``. ``loc`` is drawn from a
    ``TransformedDistribution`` whose base is ``Uniform(0, 1).mask(False)``
    and whose transform is ``AffineTransform(0, alpha)``, inside a
    ``reparam`` handler configured with ``TransformReparam`` for ``'loc'``.
    ``data`` is observed at ``'obs'`` under ``Normal(loc, 0.1)``.

    :param data: observed values passed as ``obs`` to the ``'obs'`` site.
    """
    scale = numpyro.sample('alpha', dist.Uniform(0, 1))

    with numpyro.handlers.reparam(config={'loc': TransformReparam()}):
        masked_base = dist.Uniform(0, 1).mask(False)
        scaled_prior = dist.TransformedDistribution(
            masked_base, AffineTransform(0, scale))
        center = numpyro.sample('loc', scaled_prior)

    numpyro.sample('obs', dist.Normal(center, 0.1), obs=data)
github pyro-ppl / numpyro / test / test_infer_util.py View on Github external
def model(data):
    """Observe ``data`` under a Poisson with one of two sampled rates per entry.

    ``lambda1`` and ``lambda2`` are each drawn from
    ``Exponential(1 / mean(data))``; ``tau`` is drawn from ``Uniform(0, 1)``.
    An entry at index ``i`` uses ``lambda1`` when ``i < tau * len(data)`` and
    ``lambda2`` otherwise.

    :param data: observed values passed as ``obs`` to the ``'obs'`` site.
    """
    inverse_mean = 1 / np.mean(data)
    first_rate = numpyro.sample('lambda1', dist.Exponential(inverse_mean))
    second_rate = numpyro.sample('lambda2', dist.Exponential(inverse_mean))
    tau = numpyro.sample('tau', dist.Uniform(0, 1))

    positions = np.arange(len(data))
    threshold = tau * len(data)
    per_entry_rate = np.where(positions < threshold, first_rate, second_rate)

    numpyro.sample('obs', dist.Poisson(per_entry_rate), obs=data)
github pyro-ppl / numpyro / test / test_autoguide.py View on Github external
def expected_model(data):
    """Sample ``alpha`` and ``loc`` from ``Uniform(0, 1)`` and observe ``data``.

    The value drawn at the ``'loc'`` site is multiplied by ``alpha`` before
    it is used as the location of the ``Normal(., 0.1)`` likelihood at
    ``'obs'``.

    :param data: observed values passed as ``obs`` to the ``'obs'`` site.
    """
    alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
    unit_loc = numpyro.sample('loc', dist.Uniform(0, 1))
    scaled_loc = unit_loc * alpha

    likelihood = dist.Normal(scaled_loc, 0.1)
    numpyro.sample('obs', likelihood, obs=data)
github pyro-ppl / numpyro / test / test_autoguide.py View on Github external
def actual_model(data):
    """Sample ``alpha``, then ``loc`` from ``Uniform(0, alpha)``, and observe ``data``.

    The ``'loc'`` site is sampled inside a ``reparam`` handler configured
    with ``TransformReparam`` for that site. ``data`` is observed at
    ``'obs'`` under ``Normal(loc, 0.1)``.

    :param data: observed values passed as ``obs`` to the ``'obs'`` site.
    """
    alpha = numpyro.sample('alpha', dist.Uniform(0, 1))

    loc_reparam = numpyro.handlers.reparam(config={'loc': TransformReparam()})
    with loc_reparam:
        center = numpyro.sample('loc', dist.Uniform(0, alpha))

    numpyro.sample('obs', dist.Normal(center, 0.1), obs=data)
github pyro-ppl / numpyro / numpyro / infer / initialization.py View on Github external
rng_key, subkey = random.split(rng_key)

        # this is used to interpret the changes of event_shape in
        # domain and codomain spaces
        try:
            prototype_value = site['fn'].sample(subkey, sample_shape=())
        except NotImplementedError:
            # XXX: this works for ImproperUniform prior,
            # we can't use this logic for general priors
            # because some distributions such as TransformedDistribution might
            # have wrong event_shape.
            prototype_value = jnp.full(site['fn'].shape(), jnp.nan)

        transform = biject_to(site['fn'].support)
        unconstrained_shape = jnp.shape(transform.inv(prototype_value))
        unconstrained_samples = dist.Uniform(-radius, radius).sample(
            rng_key, sample_shape=sample_shape + unconstrained_shape)
        return transform(unconstrained_samples)
github pyro-ppl / numpyro / examples / baseball.py View on Github external
def partially_pooled(at_bats, hits=None):
    r"""
    Model the number of hits as Binomial with a per-player success
    probability $\phi_i$. Each $\phi_i$ is drawn from a Beta distribution
    with concentration parameters $c_1 = m * \kappa$ and
    $c_2 = (1 - m) * \kappa$, where $m \sim Uniform(0, 1)$ and
    $\kappa \sim Pareto(1, 1.5)$.

    :param (jnp.DeviceArray) at_bats: Number of at bats for each player.
    :param (jnp.DeviceArray) hits: Number of hits for the given at bats.
    :return: Number of hits predicted by the model.
    """
    mean = numpyro.sample("m", dist.Uniform(0, 1))
    concentration = numpyro.sample("kappa", dist.Pareto(1, 1.5))

    # One plate entry per row of ``at_bats``.
    player_count = at_bats.shape[0]
    with numpyro.plate("num_players", player_count):
        c1 = mean * concentration
        c2 = (1 - mean) * concentration
        phi = numpyro.sample("phi", dist.Beta(c1, c2))
        likelihood = dist.Binomial(at_bats, probs=phi)
        return numpyro.sample("obs", likelihood, obs=hits)