How to use the numpyro.distributions.Beta distribution in NumPyro

To help you get started, we’ve selected a few NumPyro examples based on popular ways `Beta` is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github pyro-ppl / numpyro / test / test_infer_util.py View on Github external
def model(data=None):
    """Beta-Bernoulli model with a length-2 vector of success probabilities.

    A ``"beta"`` site is drawn from ``Beta(1, 1)`` (uniform on [0, 1]) for
    each of its 2 components, and ``"obs"`` is scored as Bernoulli(beta)
    inside a plate of size ``N`` placed at dim -2.

    Args:
        data: observed values for the ``"obs"`` site, or None to sample them.
            Presumably shaped ``(N, 2)`` to match the plate and ``beta`` --
            TODO confirm against the caller.

    NOTE(review): ``N`` is a free name here, presumably bound in an
    enclosing scope of the original test.
    """
    prior = dist.Beta(np.ones(2), np.ones(2))
    success_prob = numpyro.sample("beta", prior)
    with numpyro.plate("plate", N, dim=-2):
        likelihood = dist.Bernoulli(success_prob)
        numpyro.sample("obs", likelihood, obs=data)
github pyro-ppl / numpyro / test / test_infer_util.py View on Github external
def guide(data):
    """Variational guide: a Beta distribution with learnable concentrations.

    Registers two parameters, ``"alpha_q"`` and ``"beta_q"``, both
    initialised to 1.0 and constrained to be positive, then samples the
    ``"beta"`` site from ``Beta(alpha_q, beta_q)``.

    Args:
        data: not used by the body; presumably accepted so the guide's
            signature mirrors the model's.
    """
    positive = constraints.positive
    concentration1 = numpyro.param("alpha_q", 1.0, constraint=positive)
    concentration0 = numpyro.param("beta_q", 1.0, constraint=positive)
    numpyro.sample("beta", dist.Beta(concentration1, concentration0))
github pyro-ppl / numpyro / test / test_distributions.py View on Github external
class _ImproperWrapper(dist.ImproperUniform):
    """``ImproperUniform`` subclass with a concrete ``sample`` for tests.

    Samples are produced by drawing uniformly between -2 and 2 in the
    unconstrained space of ``self.support`` and mapping the draws through
    ``biject_to(self.support)``, so every returned value lies in the support.
    """

    def sample(self, key, sample_shape=()):
        """Draw samples from the support of this distribution.

        Args:
            key: PRNG key forwarded to ``random.uniform``.
            sample_shape: leading shape of the draw, prepended to
                ``self.batch_shape`` and the unconstrained event shape.

        Returns:
            The uniform draws pushed through the support's bijection.
        """
        to_support = biject_to(self.support)
        # Invert a dummy constrained value purely to learn the event shape in
        # unconstrained space, which may differ from ``self.event_shape``.
        dummy = jnp.zeros(self.event_shape)
        unconstrained_event_shape = jnp.shape(to_support.inv(dummy))
        raw = random.uniform(
            key,
            sample_shape + self.batch_shape + unconstrained_event_shape,
            minval=-2,
            maxval=2,
        )
        return to_support(raw)


# Maps each numpyro distribution class to a factory that takes the same
# parameters (in the order the lambdas list them) and builds the matching
# reference distribution from ``osp`` -- presumably ``scipy.stats``; TODO
# confirm the import alias. Where the two libraries parameterise differently
# the conversion is explicit:
#   * Exponential / Gamma: numpyro ``rate`` -> scipy ``scale = 1 / rate``;
#   * InverseGamma: numpyro ``rate`` is passed as scipy's ``scale``;
#   * LogNormal(loc, scale) -> ``lognorm(s=scale, scale=exp(loc))``;
#   * *Logits variants convert logits to probabilities via the
#     ``_to_probs_bernoulli`` / ``_to_probs_multinom`` helpers (defined
#     elsewhere in the file).
# NOTE(review): this excerpt is truncated after the MultinomialLogits entry;
# the remaining entries and the closing brace are not shown here.
_DIST_MAP = {
    dist.BernoulliProbs: lambda probs: osp.bernoulli(p=probs),
    dist.BernoulliLogits: lambda logits: osp.bernoulli(p=_to_probs_bernoulli(logits)),
    dist.Beta: lambda con1, con0: osp.beta(con1, con0),
    dist.BinomialProbs: lambda probs, total_count: osp.binom(n=total_count, p=probs),
    dist.BinomialLogits: lambda logits, total_count: osp.binom(n=total_count, p=_to_probs_bernoulli(logits)),
    dist.Cauchy: lambda loc, scale: osp.cauchy(loc=loc, scale=scale),
    dist.Chi2: lambda df: osp.chi2(df),
    dist.Dirichlet: lambda conc: osp.dirichlet(conc),
    dist.Exponential: lambda rate: osp.expon(scale=jnp.reciprocal(rate)),
    dist.Gamma: lambda conc, rate: osp.gamma(conc, scale=1. / rate),
    dist.Gumbel: lambda loc, scale: osp.gumbel_r(loc=loc, scale=scale),
    dist.HalfCauchy: lambda scale: osp.halfcauchy(scale=scale),
    dist.HalfNormal: lambda scale: osp.halfnorm(scale=scale),
    dist.InverseGamma: lambda conc, rate: osp.invgamma(conc, scale=rate),
    dist.Laplace: lambda loc, scale: osp.laplace(loc=loc, scale=scale),
    dist.LogNormal: lambda loc, scale: osp.lognorm(s=scale, scale=jnp.exp(loc)),
    dist.MultinomialProbs: lambda probs, total_count: osp.multinomial(n=total_count, p=probs),
    dist.MultinomialLogits: lambda logits, total_count: osp.multinomial(n=total_count,
                                                                        p=_to_probs_multinom(logits)),
github pyro-ppl / numpyro / test / contrib / test_funsor.py View on Github external
def model(data):
    """Two discrete latent layers feeding a Gaussian likelihood.

    A global ``"y_prob"`` is drawn from ``Beta(1, 1)``. Then, for each entry
    of ``data`` (plate ``"data"`` of size ``data.shape[0]``):

      * ``y ~ Bernoulli(y_prob)``;
      * ``z ~ Bernoulli(0.65 * y + 0.1)`` -- i.e. 0.1 when y == 0 and
        0.75 when y == 1;
      * ``obs ~ Normal(2 * z, 1)``, conditioned on ``data``.

    Args:
        data: observations indexed by the plate; presumably a 1-D array --
            TODO confirm.
    """
    p_y = numpyro.sample("y_prob", dist.Beta(1., 1.))
    num_data = data.shape[0]
    with numpyro.plate("data", num_data):
        y = numpyro.sample("y", dist.Bernoulli(p_y))
        p_z = 0.65 * y + 0.1
        z = numpyro.sample("z", dist.Bernoulli(p_z))
        obs_loc = 2. * z
        numpyro.sample("obs", dist.Normal(obs_loc, 1.), obs=data)
github pyro-ppl / numpyro / examples / hmm_enum.py View on Github external
def model_0(sequences, lengths, args, include_prior=True):
    """HMM with one discrete hidden state, unrolled sequence by sequence.

    Priors, sampled under ``numpyro_mask(mask_array=include_prior)``
    (presumably excluded from the log density when ``include_prior`` is
    False):

      * ``probs_x``: ``args.hidden_dim`` Dirichlet rows forming a transition
        matrix, with concentration 1.0 on the diagonal and 0.1 off it;
      * ``probs_y``: Beta(0.1, 0.9) emission probabilities of shape
        ``(args.hidden_dim, data_dim)``.

    Each sequence ``i`` is walked for ``lengths[i]`` steps. Every step samples
    an enumerated hidden state ``x_{i}_{t}`` (starting from state 0) and
    scores the observed tones ``sequences[i, t]`` under Bernoulli emissions
    inside a ``"tones"`` plate at dim -1.

    Args:
        sequences: array whose shape unpacks into three dims; the last is
            used as the number of tones. Presumably
            ``(num_sequences, max_length, data_dim)`` -- TODO confirm.
        lengths: per-sequence valid lengths, indexed by sequence.
        args: namespace providing ``hidden_dim``.
        include_prior: mask value applied to the prior sample sites.
    """
    _, _, data_dim = sequences.shape
    with numpyro_mask(mask_array=include_prior):
        transition_conc = 0.9 * np.eye(args.hidden_dim) + 0.1
        probs_x = pyro_sample("probs_x", dist.Dirichlet(transition_conc).to_event(1))
        # the parameter expansion here is unfortunate, and
        # necessitated by the fact that NumPyro allows some
        # batch dimensions that are not plate or enum dims
        emission_shape = (args.hidden_dim, data_dim)
        emission_prior = dist.Beta(0.1 * np.ones(emission_shape),
                                   0.9 * np.ones(emission_shape))
        probs_y = pyro_sample("probs_y", emission_prior.to_event(2))

    # Built before the sequence loop, exactly as in the original ordering.
    tones = pyro_plate("tones", data_dim, dim=-1)
    for i in pyro_plate("sequences", len(sequences)):
        seq_len = lengths[i]
        seq = sequences[i, :seq_len]
        x = 0  # every chain starts from hidden state 0
        for t in pyro_markov(range(seq_len)):
            x = pyro_sample("x_{}_{}".format(i, t), dist.Categorical(probs_x[x]),
                            infer={"enumerate": "parallel"})
            logging.info(f"x[{i}, {t}]: {x.shape}")
            with tones:
                emission = probs_y[x.squeeze(-1)]
                pyro_sample("y_{}_{}".format(i, t), dist.Bernoulli(emission),
                            obs=seq[t])
github pyro-ppl / numpyro / examples / hmm_enum.py View on Github external
def model_4(sequences, lengths, args, include_prior=True):
    """HMM variant with two discrete hidden variables, ``w`` and ``x``.

    NOTE(review): this is an excerpt -- the body is cut off right after the
    ``w_{t}`` sample site, so the ``x`` transition and the emission step are
    not shown here.

    Visible behaviour:
      * ``hidden_dim`` is ``int(args.hidden_dim ** 0.5)``; per the inline
        comment the state budget is split between ``w`` and ``x``.
      * Priors, sampled under ``numpyro_mask(mask_array=include_prior)``:
        ``probs_w`` is a ``(hidden_dim, hidden_dim)`` stack of Dirichlet rows
        with concentration 1.0 on the diagonal and 0.1 off it; ``probs_x``
        broadcasts that same concentration matrix to
        ``(hidden_dim, hidden_dim, hidden_dim)``; ``probs_y`` is
        Beta(0.1, 0.9) of shape ``(hidden_dim, hidden_dim, data_dim)``.
      * Sequences are vectorised under a ``"sequences"`` plate at dim -2, and
        each step ``t`` in ``range(max_length)`` is masked by
        ``t < lengths``.

    Args:
        sequences: array whose shape unpacks to
            ``(num_sequences, max_length, data_dim)``.
        lengths: per-sequence lengths, re-indexed by the plate's ``batch``.
        args: namespace providing ``hidden_dim``.
        include_prior: mask value applied to the prior sample sites.
    """
    num_sequences, max_length, data_dim = sequences.shape
    hidden_dim = int(args.hidden_dim ** 0.5)  # split between w and x; int() floors a non-square value
    with numpyro_mask(mask_array=include_prior):
        probs_w = pyro_sample("probs_w",
                              dist.Dirichlet(0.9 * np.eye(hidden_dim) + 0.1)
                                  .to_event(1))
        probs_x = pyro_sample("probs_x",
                              dist.Dirichlet(
                                  np.broadcast_to(0.9 * np.eye(hidden_dim) + 0.1,
                                                  (hidden_dim, hidden_dim, hidden_dim)))
                                  .to_event(2))

        probs_y_shape = (hidden_dim, hidden_dim, data_dim)
        probs_y = pyro_sample("probs_y",
                              dist.Beta(np.full(probs_y_shape, 0.1),
                                        np.full(probs_y_shape, 0.9))
                                  .to_event(len(probs_y_shape)))

    tones_plate = pyro_plate("tones", data_dim, dim=-1)
    with pyro_plate("sequences", num_sequences, dim=-2) as batch:
        # Index lengths by the plate's indices so it lines up with the batch.
        lengths = lengths[batch]
        # Note the broadcasting tricks here: we declare a hidden arange and
        # ensure that w and x are always tensors so we can unsqueeze them below,
        # thus ensuring that the x sample sites have correct distribution shape.
        # NOTE(review): no arange is visible in this excerpt; only the
        # zero-valued arrays assigned to w and x on the next line.
        w = x = np.array(0)
        for t in pyro_markov(range(max_length)):
            # Mask out steps past each sequence's length; the trailing
            # singleton dim presumably lines up with the tones plate (dim -1).
            with numpyro_mask(mask_array=(t < lengths).reshape(lengths.shape + (1,))):
                probs_ww = probs_w[w]
                # Insert the sequence-plate dims (num_sequences, 1) in front of
                # the last (event) dim, keeping any dims left of the last three
                # (presumably enumeration dims carried by a previous w).
                probs_ww = np.broadcast_to(probs_ww, probs_ww.shape[:-3] + (num_sequences, 1) + probs_ww.shape[-1:])
                w = pyro_sample("w_{}".format(t), dist.Categorical(probs_ww),
                                infer={"enumerate": "parallel"})
github pyro-ppl / numpyro / examples / hmm_enum.py View on Github external
def model_1(sequences, lengths, args, include_prior=True):
    """HMM with one discrete hidden state, vectorised over sequences.

    NOTE(review): this is an excerpt -- the body is cut off inside the
    ``tones_plate`` block right after ``probs_yx`` is computed, so the
    emission sample site is not shown here.

    Visible behaviour:
      * Priors, sampled under ``numpyro_mask(mask_array=include_prior)``:
        ``probs_x`` is ``args.hidden_dim`` Dirichlet rows with concentration
        1.0 on the diagonal and 0.1 off it; ``probs_y`` is Beta(0.1, 0.9) of
        shape ``(args.hidden_dim, data_dim)``.
      * Sequences are handled together under a ``"sequences"`` plate at
        dim -2. Each step ``t`` in ``range(max_length)`` is masked by
        ``t < lengths`` and samples an enumerated hidden state ``x_{t}``,
        starting from state 0.

    Args:
        sequences: array whose shape unpacks to
            ``(num_sequences, max_length, data_dim)``.
        lengths: per-sequence lengths, re-indexed by the plate's ``batch``.
        args: namespace providing ``hidden_dim``.
        include_prior: mask value applied to the prior sample sites.
    """
    num_sequences, max_length, data_dim = sequences.shape
    with numpyro_mask(mask_array=include_prior):
        probs_x = pyro_sample("probs_x",
                              dist.Dirichlet(0.9 * np.eye(args.hidden_dim) + 0.1)
                                  .to_event(1))
        probs_y = pyro_sample("probs_y",
                              # the parameter expansion here is unfortunate, and
                              # necessitated by the fact that NumPyro allows some
                              # batch dimensions that are not plate or enum dims
                              dist.Beta(0.1 * np.ones((args.hidden_dim, data_dim)),
                                        0.9 * np.ones((args.hidden_dim, data_dim))
                                        ).to_event(2))

    tones_plate = pyro_plate("tones", data_dim, dim=-1)
    with pyro_plate("sequences", num_sequences, dim=-2) as batch:
        # Index lengths by the plate's indices so it lines up with the batch.
        lengths = lengths[batch]
        x = 0  # every chain starts from hidden state 0
        for t in pyro_markov(range(max_length)):
            # Mask out steps past each sequence's length; the trailing
            # singleton dim presumably lines up with the tones plate (dim -1).
            with numpyro_mask(mask_array=(t < lengths).reshape(lengths.shape + (1,))):
                probs_xx = probs_x[x]
                # Insert the sequence-plate dims (num_sequences, 1) in front of
                # the last (event) dim, keeping any dims left of the last three
                # (presumably enumeration dims carried by a previous x).
                probs_xx = np.broadcast_to(probs_xx, probs_xx.shape[:-3] + (num_sequences, 1) + probs_xx.shape[-1:])
                x = pyro_sample("x_{}".format(t), dist.Categorical(probs_xx),
                                infer={"enumerate": "parallel"})
                logging.info(f"x[{t}]: {x.shape}")
                with tones_plate:
                    # Emission probabilities for the current hidden state.
                    probs_yx = probs_y[x.squeeze(-1)]
github pyro-ppl / numpyro / examples / hmm_enum.py View on Github external
def model_3(sequences, lengths, args, include_prior=True):
    num_sequences, max_length, data_dim = sequences.shape
    hidden_dim = int(args.hidden_dim ** 0.5)  # split between w and x
    with numpyro_mask(mask_array=include_prior):
        probs_w = pyro_sample("probs_w",
                              dist.Dirichlet(0.9 * np.eye(hidden_dim) + 0.1)
                                  .to_event(1))
        probs_x = pyro_sample("probs_x",
                              dist.Dirichlet(0.9 * np.eye(hidden_dim) + 0.1)
                                  .to_event(1))
        probs_y_shape = (hidden_dim, hidden_dim, data_dim)
        probs_y = pyro_sample("probs_y",
                              dist.Beta(np.full(probs_y_shape, 0.1),
                                        np.full(probs_y_shape, 0.9))
                                  .to_event(len(probs_y_shape)))

    tones_plate = pyro_plate("tones", data_dim, dim=-1)
    with pyro_plate("sequences", num_sequences, dim=-2) as batch:
        lengths = lengths[batch]
        w, x = 0, 0
        for t in pyro_markov(range(max_length)):
            with numpyro_mask(mask_array=(t < lengths).reshape(lengths.shape + (1,))):
                probs_ww = probs_w[w]
                probs_ww = np.broadcast_to(probs_ww, probs_ww.shape[:-3] + (num_sequences, 1) + probs_ww.shape[-1:])
                w = pyro_sample("w_{}".format(t), dist.Categorical(probs_ww),
                                infer={"enumerate": "parallel"})
                logging.info(f"w[{t}]: {w.shape}")

                probs_xx = probs_x[x]