How to use the numpyro.distributions module in NumPyro

To help you get started, we’ve selected a few numpyro examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github pyro-ppl / numpyro / test / test_mcmc.py View on Github external
def model(data):
    """Normal observation model over ``data`` with latent ``'mean'`` and ``'std'`` sites.

    :param data: observed values bound to the ``'obs'`` sample site.
    :return: the value returned by the ``'obs'`` sample statement.
    """
    # Normal(0, 1) prior on the location, with ``mask(False)`` applied to it.
    mean_prior = dist.Normal(0, 1).mask(False)
    mean = numpyro.sample('mean', mean_prior)

    # Improper uniform prior constrained to positive values; the two empty
    # tuples are forwarded as its remaining positional arguments.
    std_prior = dist.ImproperUniform(dist.constraints.positive, (), ())
    std = numpyro.sample('std', std_prior)

    likelihood = dist.Normal(mean, std)
    return numpyro.sample('obs', likelihood, obs=data)
github pyro-ppl / numpyro / examples / vae.py View on Github external
def reconstruct_img(epoch, rng_key):
    """Save one test image and its encoder/decoder reconstruction as grayscale PNGs.

    Both files are written under ``RESULTS_DIR``; ``epoch`` is formatted into
    each file name.

    :param epoch: value substituted into the output file names.
    :param rng_key: random key, split into one key for binarizing the image
        and one for sampling the latent ``z``.
    """
    original = test_fetch(0, test_idx)[0][0]
    original_path = os.path.join(RESULTS_DIR, 'original_epoch={}.png'.format(epoch))
    plt.imsave(original_path, original, cmap='gray')

    binarize_key, sample_key = random.split(rng_key)
    binarized = binarize(binarize_key, original)

    params = svi.get_params(svi_state)
    # The encoder is applied to the image flattened into a single row.
    encoded = encoder_nn[1](params['encoder$params'], binarized.reshape([1, -1]))
    z_mean, z_var = encoded
    # NOTE: the second encoder output is passed positionally as the Normal's
    # second argument, exactly as in the original code.
    latent = dist.Normal(z_mean, z_var).sample(sample_key)

    reconstruction = decoder_nn[1](params['decoder$params'], latent).reshape([28, 28])
    recons_path = os.path.join(RESULTS_DIR, 'recons_epoch={}.png'.format(epoch))
    plt.imsave(recons_path, reconstruction, cmap='gray')
github pyro-ppl / numpyro / numpyro / primitives.py View on Github external
def factor(name, log_factor):
    """
    Add an arbitrary log probability factor to a probabilistic model.

    The factor is registered as an observed sample site named ``name`` whose
    distribution is a ``Unit`` distribution built from ``log_factor``.

    :param str name: Name of the trivial sample.
    :param numpy.ndarray log_factor: A possibly batched log probability factor.
    """
    unit = numpyro.distributions.distribution.Unit(log_factor)
    # ``None`` is passed in place of a random key when drawing the observed value.
    sample(name, unit, obs=unit.sample(None))
github pyro-ppl / numpyro / numpyro / compat / infer.py View on Github external
def step(self, *args, rng_key=None, **kwargs):
    """Run one jitted update step and copy the resulting params into the param store.

    When ``self.svi_state`` is ``None`` the state is first created with
    ``self.init``.

    :param args: positional arguments forwarded to ``self.init`` and ``self.update``.
    :param rng_key: key used for initialization; when ``None`` and
        initialization is needed, one is drawn from a ``'svi.init'`` sample site.
    :param kwargs: keyword arguments forwarded to ``self.init`` and ``self.update``.
    :return: the loss returned by ``self.update``.
    :raises TypeError: with a NumPyro-specific message when the update fails
        with a ``TypeError`` mentioning ``'not a valid JAX type'``; any other
        ``TypeError`` is re-raised unchanged.
    """
    needs_init = self.svi_state is None
    if needs_init and rng_key is None:
        rng_key = numpyro.sample('svi.init', dist.PRNGIdentity())
    if needs_init:
        self.svi_state = self.init(rng_key, *args, **kwargs)

    try:
        self.svi_state, loss = jit(self.update)(self.svi_state, *args, **kwargs)
    except TypeError as e:
        # Only translate the JAX type complaint; everything else propagates as-is.
        if 'not a valid JAX type' not in str(e):
            raise e
        raise TypeError('NumPyro backend requires args, kwargs to be arrays or tuples, '
                        'dicts of arrays.')

    latest_params = jit(super(SVI, self).get_params)(self.svi_state)
    store = get_param_store()
    store.update(latest_params)
    return loss
github deepppl / deepppl / deepppl / utils / utils.py View on Github external
def build_hooks(npyro=False):
    if npyro:
        d = np_dist
        const = np_constraints
        provider = jnp
    else:
        d = dist
        const = constraints
        provider = torch
        
    def categorical_logits(logits):
        return d.Categorical(logits=logits)


    def bernoulli_logit(logits):
        return d.Bernoulli(logits=logits)


    def binomial_logit(n, logits):
github pyro-ppl / numpyro / numpyro / infer / autoguide.py View on Github external
def get_base_dist(self):
    """Return a Normal with zero location of length ``self.latent_dim`` and scale 1.

    ``to_event(1)`` is applied to the distribution before it is returned.
    """
    loc = jnp.zeros(self.latent_dim)
    base = dist.Normal(loc, 1)
    return base.to_event(1)
github pyro-ppl / numpyro / numpyro / infer / util.py View on Github external
"""
    (EXPERIMENTAL INTERFACE) Computes log of joint density for the model given
    latent values ``params``.

    :param model: Python callable containing NumPyro primitives.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param dict params: dictionary of current parameter values keyed by site
        name.
    :return: log of joint density and a corresponding model trace
    """
    model = substitute(model, data=params)
    model_trace = trace(model).get_trace(*model_args, **model_kwargs)
    log_joint = jnp.array(0.)
    for site in model_trace.values():
        if site['type'] == 'sample' and not isinstance(site['fn'], dist.PRNGIdentity):
            value = site['value']
            intermediates = site['intermediates']
            scale = site['scale']
            if intermediates:
                log_prob = site['fn'].log_prob(value, intermediates)
            else:
                log_prob = site['fn'].log_prob(value)

            if (scale is not None) and (not is_identically_one(scale)):
                log_prob = scale * log_prob

            log_prob = jnp.sum(log_prob)
            log_joint = log_joint + log_prob
    return log_joint, model_trace
github pyro-ppl / numpyro / numpyro / infer / reparam.py View on Github external
def __call__(self, name, fn, obs):
    """Sample from the base distribution of ``fn`` and push the draw through its transforms.

    :param name: site name; the base draw is registered as ``"<name>_base"``.
    :param fn: distribution which, after ``self._unexpand``, must be a
        ``dist.TransformedDistribution``.
    :param obs: must be ``None``; observed sites are not supported.
    :return: the tuple ``(None, value)`` where ``value`` is the transformed draw.
    """
    assert obs is None, "TransformReparam does not support observe statements"
    fn, batch_shape = self._unexpand(fn)
    assert isinstance(fn, dist.TransformedDistribution)

    # Draw noise from the base distribution, making sure it carries the same
    # batch_shape as the distribution being reparameterized.
    extra_event_dims = fn.event_dim - fn.base_dist.event_dim
    base_site = "{}_base".format(name)
    base_dist = fn.base_dist.to_event(extra_event_dims).expand(batch_shape)
    value = numpyro.sample(base_site, base_dist)

    # Differentiably transform.
    for transform in fn.transforms:
        value = transform(value)

    # Simulate a pyro.deterministic() site.
    return (None, value)
github pyro-ppl / numpyro / examples / neutra.py View on Github external
import numpyro
from numpyro import optim
from numpyro.contrib.autoguide import AutoContinuousELBO, AutoIAFNormal
from numpyro.diagnostics import print_summary
import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.infer import MCMC, NUTS, SVI
from numpyro.infer.util import initialize_model, transformed_potential_energy

# Turn off XLA's CPU fast-math mode: training the guide under fast math is unstable.
# Note that this assignment replaces any XLA_FLAGS value already set in the environment.
# TODO: remove when the issue https://github.com/google/jax/issues/939 is fixed upstream
os.environ["XLA_FLAGS"] = "--xla_cpu_enable_fast_math=false"


class DualMoonDistribution(dist.Distribution):
    """Distribution over 2-element vectors defined through its ``log_prob``.

    ``sample`` only returns zeros of the right shape; the density itself is
    given by ``log_prob``.
    """
    support = constraints.real_vector

    def __init__(self):
        super().__init__(event_shape=(2,))

    def sample(self, key, sample_shape=()):
        """Return zeros of shape ``sample_shape + event_shape``; ``key`` is unused."""
        # it is enough to return an arbitrary sample with correct shape
        shape = sample_shape + self.event_shape
        return np.zeros(shape)

    def log_prob(self, x):
        """Return the log density at ``x`` (the negated potential energy)."""
        # Quadratic penalty on the distance of the vector norm from 2.
        radius = np.linalg.norm(x, axis=-1)
        ring_term = 0.5 * ((radius - 2) / 0.4) ** 2
        # Two components from the first coordinate shifted by -2 and by +2.
        shifted = x[..., :1] + np.array([-2., 2.])
        mode_terms = -0.5 * (shifted / 0.6) ** 2
        return -(ring_term - logsumexp(mode_terms, axis=-1))
github pyro-ppl / numpyro / examples / bnn.py View on Github external
def model(X, Y, D_H):
    """Neural-network regression model with unit normal priors on all weights.

    Two ``nonlin`` layers of width ``D_H`` are followed by a linear output
    layer of width 1. The observation noise comes from a ``Gamma(3.0, 1.0)``
    prior on the precision.

    :param X: inputs; ``X.shape[1]`` is used as the input dimension.
    :param Y: observed outputs bound to the ``"Y"`` site.
    :param D_H: number of hidden units per hidden layer.
    """
    D_X, D_Y = X.shape[1], 1

    def unit_normal(shape):
        # Unit normal prior with explicit zero/one arrays of the given shape.
        return dist.Normal(np.zeros(shape), np.ones(shape))

    # first layer: weights are D_X x D_H, activations are N x D_H
    w1 = numpyro.sample("w1", unit_normal((D_X, D_H)))
    hidden1 = nonlin(np.matmul(X, w1))

    # second layer: weights are D_H x D_H, activations are N x D_H
    w2 = numpyro.sample("w2", unit_normal((D_H, D_H)))
    hidden2 = nonlin(np.matmul(hidden1, w2))

    # final layer: weights are D_H x D_Y, network output is N x D_Y
    w3 = numpyro.sample("w3", unit_normal((D_H, D_Y)))
    output = np.matmul(hidden2, w3)

    # prior on the observation precision, converted to a standard deviation
    prec_obs = numpyro.sample("prec_obs", dist.Gamma(3.0, 1.0))
    sigma_obs = 1.0 / np.sqrt(prec_obs)

    # observe data
    numpyro.sample("Y", dist.Normal(output, sigma_obs), obs=Y)