How to use the numpyro.distributions.Categorical distribution class in numpyro

To help you get started, we’ve selected a few numpyro examples, based on popular ways numpyro.distributions.Categorical is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github pyro-ppl / numpyro / test / test_mcmc.py View on Github external
def model(data):
        """Dirichlet-Categorical model over three categories.

        Draws the ``'p_latent'`` site from a Dirichlet with concentration
        ``[1.0, 1.0, 1.0]`` and conditions the ``'obs'`` site, a Categorical
        built from that draw, on ``data``.

        :param data: observed values attached to the ``'obs'`` site.
        :return: the value sampled at the ``'p_latent'`` site.
        """
        prior = dist.Dirichlet(jnp.array([1.0, 1.0, 1.0]))
        p_latent = numpyro.sample('p_latent', prior)
        likelihood = dist.Categorical(p_latent)
        numpyro.sample('obs', likelihood, obs=data)
        return p_latent
github pyro-ppl / numpyro / test / test_infer_util.py View on Github external
def model(data):
        """Dirichlet-Categorical model over three categories.

        Samples ``'p_latent'`` from a Dirichlet whose concentration is
        ``[1.0, 1.0, 1.0]``, then scores ``data`` at the ``'obs'`` site under a
        Categorical built from that sample.

        :param data: observed values attached to the ``'obs'`` site.
        :return: the value sampled at the ``'p_latent'`` site.
        """
        prior = dist.Dirichlet(np.array([1.0, 1.0, 1.0]))
        p_latent = numpyro.sample('p_latent', prior)
        likelihood = dist.Categorical(p_latent)
        numpyro.sample('obs', likelihood, obs=data)
        return p_latent
github pyro-ppl / numpyro / test / test_mcmc.py View on Github external
def test_chain_smoke(chain_method, compile_args):
    """Smoke-test NUTS/MCMC with two chains on a Dirichlet-Categorical model.

    The test only exercises ``warmup`` followed by ``run``; it makes no
    assertion about the resulting samples.

    :param chain_method: forwarded to ``MCMC`` as ``chain_method``.
    :param compile_args: forwarded to ``MCMC`` as ``jit_model_args``.
    """
    def model(data):
        concentration = jnp.array([1.0, 1.0, 1.0])
        p_latent = numpyro.sample('p_latent', dist.Dirichlet(concentration))
        numpyro.sample('obs', dist.Categorical(p_latent), obs=data)
        return p_latent

    data = dist.Categorical(jnp.array([0.1, 0.6, 0.3])).sample(random.PRNGKey(1), (2000,))
    kernel = NUTS(model)
    # Iteration counts are passed by keyword: the names document the bare
    # numbers, and newer NumPyro releases accept them only as keywords.
    mcmc = MCMC(kernel, num_warmup=2, num_samples=5, num_chains=2,
                chain_method=chain_method, jit_model_args=compile_args)
    mcmc.warmup(random.PRNGKey(0), data)
    mcmc.run(random.PRNGKey(1), data)
github pyro-ppl / numpyro / test / test_mcmc_interface.py View on Github external
concentration = np.array([1.0, 1.0, 1.0])
        p_latent = numpyro.sample('p_latent', dist.Dirichlet(concentration))
        numpyro.sample('obs', dist.Categorical(p_latent), obs=data)
        return p_latent

    @jit
    def get_samples(rng_key, data, step_size, trajectory_length, target_accept_prob):
        """Run a two-chain MCMC on ``model`` and return the collected samples.

        ``kernel_cls``, ``model``, ``warmup_steps``, ``num_samples`` and
        ``chain_method`` are taken from the enclosing scope.

        :param rng_key: random key passed to ``MCMC.run``.
        :param data: observations passed to ``MCMC.run`` as the model argument.
        :param step_size: forwarded to the kernel constructor.
        :param trajectory_length: forwarded to the kernel constructor.
        :param target_accept_prob: forwarded to the kernel constructor.
        :return: the result of ``mcmc.get_samples()``.
        """
        kernel = kernel_cls(model, step_size=step_size, trajectory_length=trajectory_length,
                            target_accept_prob=target_accept_prob)
        # Iteration counts are passed by keyword so the call also works with
        # NumPyro releases where these parameters are keyword-only.
        mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples, num_chains=2,
                    chain_method=chain_method, progress_bar=False)
        mcmc.run(rng_key, data)
        return mcmc.get_samples()

    true_probs = np.array([0.1, 0.6, 0.3])
    data = dist.Categorical(true_probs).sample(random.PRNGKey(1), (2000,))
    samples = get_samples(rng_key, data, step_size, trajectory_length, target_accept_prob)
    assert_allclose(np.mean(samples['p_latent'], 0), true_probs, atol=0.02)
github pyro-ppl / numpyro / test / test_mcmc.py View on Github external
def model(data):
        """Three-category Dirichlet-Categorical model.

        The ``'p_latent'`` site is drawn from a Dirichlet with concentration
        ``[1.0, 1.0, 1.0]``; the ``'obs'`` site is a Categorical built from
        that draw and observed at ``data``.

        :param data: observed values attached to the ``'obs'`` site.
        :return: the value sampled at the ``'p_latent'`` site.
        """
        prior = dist.Dirichlet(jnp.array([1.0, 1.0, 1.0]))
        p_latent = numpyro.sample('p_latent', prior)
        likelihood = dist.Categorical(p_latent)
        numpyro.sample('obs', likelihood, obs=data)
        return p_latent
github pyro-ppl / numpyro / test / test_mcmc_interface.py View on Github external
def model(data):
        """Three-category Dirichlet-Categorical model.

        The ``'p_latent'`` site is drawn from a Dirichlet with concentration
        ``[1.0, 1.0, 1.0]``; ``data`` is then observed at the ``'obs'`` site
        under a Categorical built from that draw.

        :param data: observed values attached to the ``'obs'`` site.
        :return: the value sampled at the ``'p_latent'`` site.
        """
        prior = dist.Dirichlet(np.array([1.0, 1.0, 1.0]))
        p_latent = numpyro.sample('p_latent', prior)
        likelihood = dist.Categorical(p_latent)
        numpyro.sample('obs', likelihood, obs=data)
        return p_latent
github pyro-ppl / numpyro / examples / hmm.py View on Github external
rng_key, rng_key_transition, rng_key_emission = random.split(rng_key, 3)

    transition_prior = np.ones(num_categories)
    emission_prior = np.repeat(0.1, num_words)

    transition_prob = dist.Dirichlet(transition_prior).sample(key=rng_key_transition,
                                                              sample_shape=(num_categories,))
    emission_prob = dist.Dirichlet(emission_prior).sample(key=rng_key_emission,
                                                          sample_shape=(num_categories,))

    start_prob = np.repeat(1. / num_categories, num_categories)
    categories, words = [], []
    for t in range(num_supervised_data + num_unsupervised_data):
        rng_key, rng_key_transition, rng_key_emission = random.split(rng_key, 3)
        if t == 0 or t == num_supervised_data:
            category = dist.Categorical(start_prob).sample(key=rng_key_transition)
        else:
            category = dist.Categorical(transition_prob[category]).sample(key=rng_key_transition)
        word = dist.Categorical(emission_prob[category]).sample(key=rng_key_emission)
        categories.append(category)
        words.append(word)

    # split into supervised data and unsupervised data
    categories, words = np.stack(categories), np.stack(words)
    supervised_categories = categories[:num_supervised_data]
    supervised_words = words[:num_supervised_data]
    unsupervised_words = words[num_supervised_data:]
    return (transition_prior, emission_prior, transition_prob, emission_prob,
            supervised_categories, supervised_words, unsupervised_words)
github pyro-ppl / numpyro / examples / hmm.py View on Github external
def semi_supervised_hmm(transition_prior, emission_prior,
                        supervised_categories, supervised_words,
                        unsupervised_words):
    """Semi-supervised hidden Markov model over categories and words.

    Transition and emission probabilities are sampled from Dirichlet priors.
    The supervised sequence is scored directly with Categorical observation
    sites, while the unsupervised words contribute through a log-probability
    term computed by ``forward_log_prob`` and added with ``numpyro.factor``.

    :param transition_prior: Dirichlet concentration for one transition row;
        its first dimension gives the number of categories.
    :param emission_prior: Dirichlet concentration for one emission row; its
        first dimension gives the number of words.
    :param supervised_categories: categories of the supervised sequence.
    :param supervised_words: words observed alongside ``supervised_categories``.
    :param unsupervised_words: words whose categories are not observed.
    """
    num_categories = transition_prior.shape[0]
    num_words = emission_prior.shape[0]

    # Broadcast each prior so there is one Dirichlet row per category.
    transition_concentration = np.broadcast_to(
        transition_prior, (num_categories, num_categories))
    transition_prob = numpyro.sample(
        'transition_prob', dist.Dirichlet(transition_concentration))
    emission_concentration = np.broadcast_to(
        emission_prior, (num_categories, num_words))
    emission_prob = numpyro.sample(
        'emission_prob', dist.Dirichlet(emission_concentration))

    # Supervised data: each category is observed given its predecessor.
    # No assumption is made about the first supervised category, i.e. it
    # gets a flat/uniform prior.
    previous_categories = supervised_categories[:-1]
    following_categories = supervised_categories[1:]
    numpyro.sample('supervised_categories',
                   dist.Categorical(transition_prob[previous_categories]),
                   obs=following_categories)
    numpyro.sample('supervised_words',
                   dist.Categorical(emission_prob[supervised_categories]),
                   obs=supervised_words)

    # Unsupervised data: start from the emission log-probabilities of the
    # first word and hand the remaining words to forward_log_prob.
    transition_log_prob = np.log(transition_prob)
    emission_log_prob = np.log(emission_prob)
    first_word = unsupervised_words[0]
    remaining_words = unsupervised_words[1:]
    init_log_prob = emission_log_prob[:, first_word]
    forward_result = forward_log_prob(init_log_prob, remaining_words,
                                      transition_log_prob, emission_log_prob)
    reduced_log_prob = logsumexp(forward_result, axis=0, keepdims=True)

    # Inject the reduced log-probability into the potential function.
    numpyro.factor('forward_log_prob', reduced_log_prob)
github pyro-ppl / numpyro / examples / hmm_enum.py View on Github external
tones_plate = pyro_plate("tones", data_dim, dim=-1)
    with pyro_plate("sequences", num_sequences, dim=-2) as batch:
        lengths = lengths[batch]
        w, x = 0, 0
        for t in pyro_markov(range(max_length)):
            with numpyro_mask(mask_array=(t < lengths).reshape(lengths.shape + (1,))):
                probs_ww = probs_w[w]
                probs_ww = np.broadcast_to(probs_ww, probs_ww.shape[:-3] + (num_sequences, 1) + probs_ww.shape[-1:])
                w = pyro_sample("w_{}".format(t), dist.Categorical(probs_ww),
                                infer={"enumerate": "parallel"})
                logging.info(f"w[{t}]: {w.shape}")

                probs_xx = probs_x[x]
                probs_xx = np.broadcast_to(probs_xx, probs_xx.shape[:-3] + (num_sequences, 1) + probs_xx.shape[-1:])
                x = pyro_sample("x_{}".format(t), dist.Categorical(probs_xx),
                                infer={"enumerate": "parallel"})
                logging.info(f"x[{t}]: {x.shape}")

                with tones_plate as tones:
                    probs_ywx = probs_y[w, x, tones]
                    probs_ywx = np.broadcast_to(
                        probs_ywx, probs_ywx.shape[:-2] + (num_sequences,) + probs_ywx.shape[-1:])
                    pyro_sample("y_{}".format(t), dist.Bernoulli(probs_ywx),
                                obs=sequences[batch, t])
github pyro-ppl / numpyro / examples / hmm_enum.py View on Github external
.to_event(1))
        probs_y_shape = (hidden_dim, hidden_dim, data_dim)
        probs_y = pyro_sample("probs_y",
                              dist.Beta(np.full(probs_y_shape, 0.1),
                                        np.full(probs_y_shape, 0.9))
                                  .to_event(len(probs_y_shape)))

    tones_plate = pyro_plate("tones", data_dim, dim=-1)
    with pyro_plate("sequences", num_sequences, dim=-2) as batch:
        lengths = lengths[batch]
        w, x = 0, 0
        for t in pyro_markov(range(max_length)):
            with numpyro_mask(mask_array=(t < lengths).reshape(lengths.shape + (1,))):
                probs_ww = probs_w[w]
                probs_ww = np.broadcast_to(probs_ww, probs_ww.shape[:-3] + (num_sequences, 1) + probs_ww.shape[-1:])
                w = pyro_sample("w_{}".format(t), dist.Categorical(probs_ww),
                                infer={"enumerate": "parallel"})
                logging.info(f"w[{t}]: {w.shape}")

                probs_xx = probs_x[x]
                probs_xx = np.broadcast_to(probs_xx, probs_xx.shape[:-3] + (num_sequences, 1) + probs_xx.shape[-1:])
                x = pyro_sample("x_{}".format(t), dist.Categorical(probs_xx),
                                infer={"enumerate": "parallel"})
                logging.info(f"x[{t}]: {x.shape}")

                with tones_plate as tones:
                    probs_ywx = probs_y[w, x, tones]
                    probs_ywx = np.broadcast_to(
                        probs_ywx, probs_ywx.shape[:-2] + (num_sequences,) + probs_ywx.shape[-1:])
                    pyro_sample("y_{}".format(t), dist.Bernoulli(probs_ywx),
                                obs=sequences[batch, t])