How to use the numpyro.infer.NUTS function in numpyro

To help you get started, we’ve selected a few numpyro examples based on popular ways `numpyro.infer.NUTS` is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

GitHub: pyro-ppl / numpyro / test / test_compile.py (view on GitHub)
def test_mcmc_parallel_chain(deterministic):
    """Check how many times ``model`` runs when two chains are sampled.

    ``GLOBAL["count"]`` is reset to zero first; the test then samples two
    chains and expects the counter to reach 4 in deterministic mode and 3
    otherwise. NOTE(review): the counter is presumably incremented inside
    ``model`` (defined elsewhere) each time it is traced — confirm.
    """
    GLOBAL["count"] = 0
    sampler = MCMC(NUTS(model), num_warmup=100, num_samples=100, num_chains=2)
    sampler.run(random.PRNGKey(0), deterministic=deterministic)
    sampler.get_samples()

    expected_count = 4 if deterministic else 3
    assert GLOBAL["count"] == expected_count
GitHub: pyro-ppl / numpyro / test / test_mcmc.py (view on GitHub)
def test_chain_smoke(chain_method, compile_args):
    """Smoke-test separate ``warmup`` then ``run`` calls on a two-chain sampler.

    Fits a Dirichlet-Categorical model to 2000 categorical draws, with the
    chain method and model-argument jitting supplied by the parametrization.
    """
    def dirichlet_categorical(data):
        # Flat Dirichlet prior over the three category probabilities.
        p_latent = numpyro.sample('p_latent', dist.Dirichlet(jnp.ones(3)))
        numpyro.sample('obs', dist.Categorical(p_latent), obs=data)
        return p_latent

    true_probs = jnp.array([0.1, 0.6, 0.3])
    data = dist.Categorical(true_probs).sample(random.PRNGKey(1), (2000,))

    mcmc = MCMC(NUTS(dirichlet_categorical), num_warmup=2, num_samples=5, num_chains=2,
                chain_method=chain_method, jit_model_args=compile_args)
    mcmc.warmup(random.PRNGKey(0), data)
    mcmc.run(random.PRNGKey(1), data)
GitHub: pyro-ppl / numpyro / test / test_mcmc.py (view on GitHub)
def test_empty_model(num_chains, chain_method, progress_bar):
    """A model with no sample sites should run and yield an empty sample dict."""
    def no_sites():
        return None  # deliberately declares no latent or observed sites

    kernel = NUTS(no_sites)
    mcmc = MCMC(kernel, num_warmup=10, num_samples=10, num_chains=num_chains,
                chain_method=chain_method, progress_bar=progress_bar)
    mcmc.run(random.PRNGKey(0))

    samples = mcmc.get_samples()
    assert samples == {}
GitHub: pyro-ppl / numpyro / test / contrib / test_control_flow.py (view on GitHub)
mu1 = beta * mu0 + x1
            y1 = numpyro.sample('y', dist.Normal(mu1, r))
            numpyro.deterministic('y2', y1 * 2)
            return (x1, mu1), (x1, y1)

        mu0 = x0 = numpyro.sample('x_0', dist.Normal(0, q))
        y0 = numpyro.sample('y_0', dist.Normal(mu0, r))

        _, xy = scan(transition, (x0, mu0), jnp.arange(T))
        x, y = xy

        return jnp.append(x0, x), jnp.append(y0, y)

    T = 10
    num_samples = 100
    kernel = NUTS(model)
    mcmc = MCMC(kernel, 100, num_samples)
    mcmc.run(jax.random.PRNGKey(0), T=T)
    assert set(mcmc.get_samples()) == {'x', 'y', 'y2', 'x_0', 'y_0'}
    mcmc.print_summary()

    samples = mcmc.get_samples()
    x = samples.pop('x')[0]  # take 1 sample of x
    # this tests for the composition of condition and substitute
    # this also tests if we can use `vmap` for predictive.
    future = 5
    predictive = Predictive(numpyro.handlers.condition(model, {'x': x}),
                            samples, return_sites=['x', 'y', 'y2'], parallel=True)
    result = predictive(jax.random.PRNGKey(1), T=T + future)
    expected_shape = (num_samples, T + future)
    assert result['x'].shape == expected_shape
    assert result['y'].shape == expected_shape
GitHub: pyro-ppl / numpyro / test / contrib / test_funsor.py (view on GitHub)
trans_prob = transition[x]

    def _generate_data():
        transition_probs = np.random.rand(dim, dim)
        transition_probs = transition_probs / transition_probs.sum(-1, keepdims=True)
        emissions_loc = np.arange(dim)
        emissions_scale = 1.
        state = np.random.choice(3)
        obs = [np.random.normal(emissions_loc[state], emissions_scale)]
        for _ in range(num_steps - 1):
            state = np.random.choice(dim, p=transition_probs[state])
            obs.append(np.random.normal(emissions_loc[state], emissions_scale))
        return np.stack(obs)

    data = _generate_data()
    nuts_kernel = NUTS(model)
    mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=500)
    mcmc.run(random.PRNGKey(0), data)
GitHub: pyro-ppl / numpyro / examples / covtype.py (view on GitHub)
def benchmark_hmc(args, features, labels):
    """Time NUTS sampling of ``model`` on (features, labels) and print results.

    Draws ``args.num_samples`` samples with no warmup, prints the MCMC
    summary, and prints the elapsed time for kernel setup plus sampling.

    Args:
        args: parsed options; reads ``num_steps`` and ``num_samples``.
        features: data array; only ``features.shape[0]`` is used here (to
            scale the step size) before it is passed on to ``model``.
        labels: observations, passed straight through to ``model``.
    """
    # Hand-chosen step size that shrinks as the number of rows grows. Because
    # num_warmup=0 below, there is no warmup phase in which a step size could be
    # adapted, so the value must be handed to the kernel. Previously it was
    # computed and then dropped, leaving the kernel at its default step size.
    step_size = np.sqrt(0.5 / features.shape[0])
    # NOTE: numpyro documents `trajectory_length` as having no effect in NUTS
    # (NUTS picks its own trajectory length). It is kept for parity with an
    # HMC configuration of this benchmark.
    trajectory_length = step_size * args.num_steps
    rng_key = random.PRNGKey(1)
    start = time.perf_counter()  # monotonic, high-resolution clock for durations
    kernel = NUTS(model, step_size=step_size, trajectory_length=trajectory_length)
    mcmc = MCMC(kernel, num_warmup=0, num_samples=args.num_samples)
    mcmc.run(rng_key, features, labels)
    mcmc.print_summary()
    print('\nMCMC elapsed time:', time.perf_counter() - start)
GitHub: pyro-ppl / numpyro / benchmarks / hmm.py (view on GitHub)
def numpyro_inference(data, args):
    """Benchmark NUTS inference on ``numpyro_model`` and print timings.

    Compilation, warmup and sampling are run as separate steps so that three
    timings can be separated: compile time (printed), total warmup + sampling
    time (printed), and sampling-only time (stored in ``sampling_time``).

    Args:
        data: sequence of positional model arguments, unpacked into every
            compile/warmup/run call (its schema is not visible here).
        args: parsed options; reads ``seed``, ``num_warmup``, ``num_samples``,
            ``num_chains`` and ``disable_progbar``.
    """
    rng_key = jax.random.PRNGKey(args.seed)
    kernel = numpyro.infer.NUTS(numpyro_model)
    mcmc = numpyro.infer.MCMC(kernel, args.num_warmup, args.num_samples,
                              num_chains=args.num_chains, progress_bar=not args.disable_progbar)
    tic = time.time()
    # Private API: compile ahead of time so that compilation is reported on its
    # own and is not counted in the inference timings below.
    mcmc._compile(rng_key, *data, extra_fields=('num_steps',))
    print('MCMC (numpyro) compiling time:', time.time() - tic, '\n')
    tic = time.time()  # start of total (warmup + sampling) timing
    mcmc.warmup(rng_key, *data, extra_fields=('num_steps',))
    # NOTE(review): re-assigns the value already given to the constructor and
    # appears to be a no-op. Confirm before removing.
    mcmc.num_samples = args.num_samples
    # Continue sampling from the RNG key stored in the post-warmup state.
    rng_key = mcmc._warmup_state.rng_key.copy()
    tic_run = time.time()  # start of sampling-only timing
    mcmc.run(rng_key, *data, extra_fields=('num_steps',))
    # The result is discarded. Presumably this forces JAX's asynchronous
    # dispatch to finish before `toc` is read (older JAX `.copy()` pulled the
    # array to host). TODO confirm; `block_until_ready()` would state the
    # intent explicitly.
    mcmc._last_state.rng_key.copy()
    toc = time.time()
    mcmc.print_summary()
    print('\nMCMC (numpyro) elapsed time:', toc - tic)
    # NOTE(review): `sampling_time` is not used afterwards in this body.
    # Either report it or remove it.
    sampling_time = toc - tic_run
GitHub: pyro-ppl / numpyro / examples / baseball.py (view on GitHub)
def run_inference(model, at_bats, hits, rng_key, args):
    """Sample the posterior of ``model`` using the algorithm named by ``args.algo``.

    Args:
        model: numpyro model; ``at_bats`` and ``hits`` are forwarded to it by
            ``MCMC.run``.
        at_bats: first positional data argument for ``model``.
        hits: second positional data argument for ``model``.
        rng_key: JAX PRNG key for the sampler.
        args: parsed options; reads ``algo`` ("NUTS", "HMC" or "SA"),
            ``num_warmup``, ``num_samples``, ``num_chains`` and
            ``disable_progbar``.

    Returns:
        The posterior samples returned by ``MCMC.get_samples()``.

    Raises:
        ValueError: if ``args.algo`` is not one of "NUTS", "HMC" or "SA".
    """
    if args.algo == "NUTS":
        kernel = NUTS(model)
    elif args.algo == "HMC":
        kernel = HMC(model)
    elif args.algo == "SA":
        kernel = SA(model)
    else:
        # Without this branch an unrecognised name fell through and crashed
        # later with an opaque UnboundLocalError on `kernel`.
        raise ValueError(
            "unknown algo {!r}; expected 'NUTS', 'HMC' or 'SA'".format(args.algo))
    # Hide the progress bar when NUMPYRO_SPHINXBUILD is set (presumably during
    # documentation builds) or when the user disabled it.
    progress_bar = not ("NUMPYRO_SPHINXBUILD" in os.environ or args.disable_progbar)
    mcmc = MCMC(kernel, num_warmup=args.num_warmup, num_samples=args.num_samples,
                num_chains=args.num_chains, progress_bar=progress_bar)
    mcmc.run(rng_key, at_bats, hits)
    return mcmc.get_samples()
GitHub: pyro-ppl / numpyro / examples / cb.py (view on GitHub)
def run_hmc(model, args, rng_key, X, Y, hypers):
    """Run single-chain NUTS on ``model`` and measure how long it takes.

    Args:
        model: numpyro model; ``X``, ``Y`` and ``hypers`` are forwarded to it
            by ``MCMC.run``.
        args: dict of options; reads ``'mtd'`` (NUTS ``max_tree_depth``),
            ``'num_warmup'`` and ``'num_samples'``.
        rng_key: JAX PRNG key for the sampler.
        X, Y, hypers: positional arguments passed through to ``model``.

    Returns:
        Tuple ``(samples, elapsed_time)``: the posterior samples and the
        wall-clock seconds spent building the kernel and running MCMC.
    """
    t_begin = time.time()
    nuts = NUTS(model, max_tree_depth=args['mtd'])
    # The progress bar is hidden whenever NUMPYRO_SPHINXBUILD is set.
    hide_progress = "NUMPYRO_SPHINXBUILD" in os.environ
    sampler = MCMC(nuts, num_warmup=args['num_warmup'], num_samples=args['num_samples'],
                   num_chains=1, progress_bar=not hide_progress)
    sampler.run(rng_key, X, Y, hypers)
    elapsed_time = time.time() - t_begin
    return sampler.get_samples(), elapsed_time