How to use the numpyro.param function in numpyro

To help you get started, we’ve selected a few examples of `numpyro.param`, based on popular ways the function is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

GitHub: pyro-ppl / numpyro / test / test_autoguide.py (View on GitHub)
def model():
    """Sample site ``x`` from a Normal whose loc and scale are learnable params.

    ``a_init``, ``a_minval`` and ``b_init`` are free variables not defined in
    this function (presumably bound by the enclosing test -- confirm).
    Param ``'a'`` is constrained to exceed ``a_minval``; param ``'b'`` is
    constrained to be positive.
    """
    lower_bound = constraints.greater_than(a_minval)
    loc = numpyro.param('a', a_init, constraint=lower_bound)
    scale = numpyro.param('b', b_init, constraint=constraints.positive)
    numpyro.sample('x', dist.Normal(loc, scale))
GitHub: pyro-ppl / numpyro / test / test_infer_util.py (View on GitHub)
def guide(data):
    """Guide placing a Beta over site ``beta`` with learnable parameters.

    Both Beta parameters are registered as params (``"alpha_q"`` and
    ``"beta_q"``), initialised to 1.0 and constrained to be positive.
    ``data`` is not used in this body (presumably accepted only to mirror the
    model's signature).
    """
    positive = constraints.positive
    q_alpha = numpyro.param("alpha_q", 1.0, constraint=positive)
    q_beta = numpyro.param("beta_q", 1.0, constraint=positive)
    numpyro.sample("beta", dist.Beta(q_alpha, q_beta))
GitHub: pyro-ppl / numpyro / test / test_mcmc_interface.py (View on GitHub)
def model(data):
    """Normal likelihood for ``data`` with a learnable mean and std.

    Registers param ``'mean'`` (initialised to 0.) and param ``'std'``
    (initialised to 1., constrained positive), then observes ``data`` at
    site ``'obs'``. Returns whatever ``numpyro.sample`` returns for that site.
    """
    loc = numpyro.param('mean', 0.)
    scale = numpyro.param('std', 1., constraint=constraints.positive)
    likelihood = dist.Normal(loc, scale)
    return numpyro.sample('obs', likelihood, obs=data)
GitHub: pyro-ppl / numpyro / test / test_svi.py (View on GitHub)
def guide(data):
    """Guide placing a Beta over site ``beta`` with learnable parameters.

    Both Beta parameters are registered as params (``"alpha_q"`` and
    ``"beta_q"``), initialised to 1.0 and constrained to be positive.
    ``data`` is not used in this body (presumably accepted only to mirror the
    model's signature).
    """
    positive = constraints.positive
    q_alpha = numpyro.param("alpha_q", 1.0, constraint=positive)
    q_beta = numpyro.param("beta_q", 1.0, constraint=positive)
    numpyro.sample("beta", dist.Beta(q_alpha, q_beta))
GitHub: pyro-ppl / numpyro / test / test_handlers.py (View on GitHub)
        y = handlers.substitute(lambda: numpyro.param('y', None) * numpyro.param('x', None), {'y': x})()
        return x + y
GitHub: pyro-ppl / numpyro / test / test_handlers.py (View on GitHub)
def model():
    """Return ``x + y``, where ``y`` comes from a ``substitute``-wrapped call.

    Param ``'x'`` is registered with init ``None``. ``y`` is computed as
    ``param('y') * param('x')`` inside ``handlers.substitute`` with the
    mapping ``{'y': x}``, so the inner ``'y'`` param is replaced by ``x``.
    """
    x = numpyro.param('x', None)

    def _y_times_x():
        # Evaluation order matters for the effect stack: 'y' first, then 'x'.
        return numpyro.param('y', None) * numpyro.param('x', None)

    y = handlers.substitute(_y_times_x, {'y': x})()
    return x + y
GitHub: pyro-ppl / numpyro / numpyro / infer / autoguide.py (View on GitHub)
def _get_posterior(self):
    """Build a multivariate normal posterior from two learnable params.

    Registers ``<prefix>_loc`` (initialised to ``self._init_latent``) and
    ``<prefix>_scale_tril`` (initialised to ``self._init_scale`` times the
    ``self.latent_dim``-sized identity, constrained to be lower-Cholesky),
    where ``<prefix>`` is ``self.prefix``.

    Returns:
        ``dist.MultivariateNormal`` parameterised by the loc and the
        Cholesky factor ``scale_tril``.
    """
    prefix = self.prefix
    loc = numpyro.param('{}_loc'.format(prefix), self._init_latent)
    # Diagonal starting factor: a scaled identity matrix.
    init_scale_tril = jnp.identity(self.latent_dim) * self._init_scale
    scale_tril = numpyro.param('{}_scale_tril'.format(prefix),
                               init_scale_tril,
                               constraint=constraints.lower_cholesky)
    return dist.MultivariateNormal(loc, scale_tril=scale_tril)
GitHub: pyro-ppl / numpyro / examples / pairwise.py (View on GitHub)
def bernoulli_guide(X, Y, hypers, method="direct", num_probes=4, cg_tol=0.001):
    """Guide with point-mass global sites and a transformed Normal over ``omega``.

    Sites ``eta1``, ``msq``, ``xisq`` and ``lambda`` are ``Delta`` distributions
    located at positive-constrained params. ``omega`` is a Normal pushed
    through a sigmoid and then scaled by 2.5, so its samples lie in (0, 2.5).

    Args:
        X: array whose ``shape[0]`` is used as N and ``shape[1]`` as P
            (presumably an (N, P) design matrix -- confirm).
        Y, method, num_probes, cg_tol: not used in this body.
        hypers: mapping read for ``'expected_sparsity'`` and ``'sigma'``.
    """
    S = hypers['expected_sparsity']
    sigma = hypers['sigma']
    P = X.shape[1]
    N = X.shape[0]

    # NOTE(review): phi (and omega below) are unused here; this example looks
    # truncated relative to the upstream source -- confirm before removing.
    phi = sigma * (S / np.sqrt(N)) / (P - S)

    # Scalar point-mass sites: (param name, sample site, initial value).
    # Each is a param-then-sample pair, in the original order.
    scalar_sites = (
        ("eta1_loc", "eta1", 0.25),
        ("msq_loc", "msq", 1.0),
        ("xisq_loc", "xisq", 1.0),
    )
    for param_name, site_name, init_value in scalar_sites:
        point = numpyro.param(param_name, init_value, constraint=constraints.positive)
        numpyro.sample(site_name, dist.Delta(point))

    # Length-P point mass for lambda.
    lam_point = numpyro.param("lam_loc", 0.5 * np.ones(P), constraint=constraints.positive)
    numpyro.sample("lambda", dist.Delta(lam_point))

    omega_loc = numpyro.param('omega_loc', -2.0 * np.ones(N))
    omega_scale = numpyro.param('omega_scale', 0.8 * np.ones(N), constraint=constraints.positive)
    normal = dist.Normal(omega_loc, omega_scale)
    # Sigmoid maps into (0, 1); the affine transform then rescales to (0, 2.5).
    squash = [SigmoidTransform(), AffineTransform(0, 2.5)]
    omega = numpyro.sample("omega", dist.TransformedDistribution(normal, squash))
GitHub: pyro-ppl / numpyro / numpyro / infer / reparam.py (View on GitHub)
def __call__(self, name, fn, obs):
    """Reparameterize a location-scale sample site by (partial) decentering.

    Args:
        name: name of the sample site; used to name the auxiliary
            ``"<name>_centered"`` param and ``"<name>_decentered"`` site.
        fn: the site's distribution; after unwrapping it must expose ``loc``,
            ``scale`` and the attributes listed in ``self.shape_params``.
        obs: must be ``None``; observed sites are not supported.

    Returns:
        A 2-tuple ``(new_fn, value)``. If ``centered`` is identically one,
        the site is left untouched and ``(fn, obs)`` is returned. Otherwise
        ``(None, value)`` is returned, where ``value`` is the differentiable
        transform of the decentered draw.

    Raises:
        AssertionError: if ``obs`` is not ``None``.
    """
    assert obs is None, "LocScaleReparam does not support observe statements"
    centered = self.centered
    if is_identically_one(centered):
        # Fully centered: nothing to do. Return the same 2-tuple shape as the
        # ``(None, value)`` path below so callers can always unpack
        # ``new_fn, value`` (previously a 3-tuple ``(name, fn, obs)``).
        return fn, obs
    event_shape = fn.event_shape
    # NOTE(review): _unwrap/_unexpand are defined elsewhere; presumably they
    # strip event/batch wrappers that _wrap/.expand re-apply below -- confirm.
    fn, event_dim = self._unwrap(fn)
    fn, batch_shape = self._unexpand(fn)

    # Apply a partial decentering transform.
    params = {key: getattr(fn, key) for key in self.shape_params}
    if self.centered is None:
        # Learn the degree of centering per event element, starting halfway.
        centered = numpyro.param("{}_centered".format(name),
                                 jnp.full(event_shape, 0.5),
                                 constraint=constraints.unit_interval)
    params["loc"] = fn.loc * centered
    params["scale"] = fn.scale ** centered
    decentered_fn = type(fn)(**params).expand(batch_shape)

    # Draw decentered noise.
    decentered_value = numpyro.sample("{}_decentered".format(name),
                                      self._wrap(decentered_fn, event_dim))

    # Differentiably transform back: with centered == 1 this is the identity,
    # with centered == 0 it is loc + scale * decentered_value.
    delta = decentered_value - centered * fn.loc
    value = fn.loc + jnp.power(fn.scale, 1 - centered) * delta

    # Simulate a pyro.deterministic() site.
    return None, value
GitHub: pyro-ppl / numpyro / examples / utils.py (View on GitHub)
def record_stats(stat_value, num_stats=2):
    """Tie ``stat_value`` to the gradient of a ``'stats'`` param via a zero factor.

    The factor's value is ``-stat + stop_gradient(stat)``, which is zero for
    finite inputs, so it does not change the log density's value. Its
    derivative with respect to the ``'stats'`` param (initialised to
    ``num_stats`` zeros) is ``-stop_gradient(stat_value)``. Presumably callers
    recover the statistic from that gradient -- confirm at the call site.
    """
    stats_param = numpyro.param('stats', np.zeros(num_stats))
    frozen_value = stop_gradient(stat_value)
    stat = stats_param * frozen_value
    numpyro.factor('stats_dummy_factor', -stat + stop_gradient(stat))