How to use the numpyro.distributions.Normal function in numpyro

To help you get started, we’ve selected a few NumPyro examples of `numpyro.distributions.Normal`, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes—no build needed—and fix issues immediately.

github pyro-ppl / numpyro / test / test_svi.py View on Github external
d = numpyro.param('d', d_init, constraint=constraints.unit_interval)
        numpyro.sample('y', dist.Normal(c, d), obs=obs)

    adam = optim.Adam(0.01)
    svi = SVI(model, guide, adam, ELBO())
    svi_state = svi.init(random.PRNGKey(0))

    params = svi.get_params(svi_state)
    assert_allclose(params['a'], a_init)
    assert_allclose(params['b'], b_init)
    assert_allclose(params['c'], c_init)
    assert_allclose(params['d'], d_init)

    actual_loss = svi.evaluate(svi_state)
    assert jnp.isfinite(actual_loss)
    expected_loss = dist.Normal(c_init, d_init).log_prob(obs) - dist.Normal(a_init, b_init).log_prob(obs)
    # not so precisely because we do transform / inverse transform stuffs
    assert_allclose(actual_loss, expected_loss, rtol=1e-6)
github pyro-ppl / numpyro / test / test_mcmc.py View on Github external
def model(a=None, b=None, z=None):
    """Regression model with optional covariates ``a`` and ``b``.

    The intercept is sampled at site ``'a'`` from Normal(0, 0.2). Each
    supplied covariate is multiplied by its own HalfNormal(0.5) coefficient:
    site ``'x'`` for ``a`` and site ``'y'`` for ``b``. An omitted covariate
    contributes 0 and creates no site. The ``'obs'`` site is Normal around
    the summed predictor, with an Exponential(1) scale at site ``'sigma'``.
    It is conditioned on ``z`` when ``z`` is given.
    """
    intercept = numpyro.sample('a', dist.Normal(0., 0.2))
    a_contrib = 0.
    if a is not None:
        a_contrib = a * numpyro.sample('x', dist.HalfNormal(0.5))
    b_contrib = 0.
    if b is not None:
        b_contrib = b * numpyro.sample('y', dist.HalfNormal(0.5))
    noise_scale = numpyro.sample('sigma', dist.Exponential(1.))
    predictor = intercept + a_contrib + b_contrib
    numpyro.sample('obs', dist.Normal(predictor, noise_scale), obs=z)
github pyro-ppl / numpyro / test / test_mcmc.py View on Github external
def model(data):
    """Normal likelihood with a masked mean prior and an improper scale prior.

    The ``'mean'`` site uses a Normal(0, 1) prior wrapped in ``mask(False)``.
    The ``'std'`` site uses an ``ImproperUniform`` over the positive
    constraint with empty batch and event shapes. Returns the value of the
    ``'obs'`` site, which is conditioned on ``data``.
    """
    loc = numpyro.sample('mean', dist.Normal(0, 1).mask(False))
    positive_support = dist.constraints.positive
    scale = numpyro.sample('std', dist.ImproperUniform(positive_support, (), ()))
    likelihood = dist.Normal(loc, scale)
    return numpyro.sample('obs', likelihood, obs=data)
github pyro-ppl / numpyro / test / test_mcmc_interface.py View on Github external
def schools_model():
    """Hierarchical schools model that reads ``data`` from the enclosing scope.

    ``data`` must provide three keys:

    - ``'J'``: the number of per-group ``theta`` draws.
    - ``'sigma'``: the observation scale(s).
    - ``'y'``: the observed values.
    """
    population_mean = numpyro.sample('mu', dist.Normal(0, 5))
    population_spread = numpyro.sample('tau', dist.HalfCauchy(5))
    n_groups = data['J']
    group_effects = numpyro.sample(
        'theta',
        dist.Normal(population_mean, population_spread),
        sample_shape=(n_groups,),
    )
    numpyro.sample('obs', dist.Normal(group_effects, data['sigma']), obs=data['y'])
github pyro-ppl / numpyro / test / test_mcmc_interface.py View on Github external
def model(data):
    """Infer a location ``loc`` (Normal(0, 1) prior) from unit-scale observations."""
    prior = dist.Normal(0., 1.)
    loc = numpyro.sample('loc', prior)
    numpyro.sample('obs', dist.Normal(loc, 1), obs=data)
github pyro-ppl / numpyro / test / test_handlers.py View on Github external
def _sample():
    """Sample ``x ~ Normal(0, 1)`` then ``y ~ Normal(1, 2)`` and stack them as ``[x, y]``."""
    draws = [
        numpyro.sample('x', dist.Normal(0., 1.)),
        numpyro.sample('y', dist.Normal(1., 2.)),
    ]
    return jnp.stack(draws)
github pyro-ppl / numpyro / test / test_distributions.py View on Github external
def test_mask(batch_shape, event_shape, mask_shape):
    """Check that masked log-probs are nonzero exactly where the broadcast mask is True."""
    n_event_dims = len(event_shape)
    base = dist.Normal().expand(batch_shape + event_shape).to_event(n_event_dims)
    mask = dist.Bernoulli(0.5).sample(random.PRNGKey(0), mask_shape)
    if mask_shape == ():
        # A scalar mask is exercised as a plain Python bool.
        mask = bool(mask)
    value = base.sample(random.PRNGKey(1))
    log_density = base.mask(mask).log_prob(value)
    expected_nonzero = jnp.broadcast_to(mask, lax.broadcast_shapes(batch_shape, mask_shape))
    assert_allclose(log_density != 0, expected_nonzero)
github pyro-ppl / numpyro / examples / funnel.py View on Github external
def reparam_model(dim=10):
    """Funnel-shaped model whose ``x`` site is reparameterized.

    ``y`` is drawn from Normal(0, 3). ``x`` holds ``dim - 1`` zero-mean
    Normal draws with scale ``exp(y / 2)``. The ``x`` sample is taken inside
    ``numpyro.handlers.reparam`` configured with ``LocScaleReparam(0)``.
    """
    y = numpyro.sample('y', dist.Normal(0, 3))
    x_scale = jnp.exp(y / 2)
    reparam_config = {'x': LocScaleReparam(0)}
    with numpyro.handlers.reparam(config=reparam_config):
        numpyro.sample('x', dist.Normal(jnp.zeros(dim - 1), x_scale))