How to use the jax.numpy.exp function in jax

To help you get started, we’ve selected a few examples of jax.numpy.exp, based on popular ways it is used in public projects.

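Before the project snippets, here is a minimal, self-contained sketch of jax.numpy.exp by itself: an element-wise exponential that composes with jax transformations such as jit and grad. Everything in it is illustrative:

import jax
import jax.numpy as jnp

x = jnp.array([-1.0, 0.0, 1.0, 2.0])
print(jnp.exp(x))                  # element-wise e**x: [0.368, 1., 2.718, 7.389]

# jnp.exp composes with jax transformations; d/dx exp(x) = exp(x).
print(jax.grad(jnp.exp)(1.0))      # ~2.71828
print(jax.jit(jnp.exp)(x))         # same values, compiled with XLA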

github pyro-ppl / numpyro / test / test_mcmc.py
# Imports this excerpt relies on (in the original test file, jit_args is a pytest parameter).
import numpy as np

from jax import random
import jax.numpy as jnp

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS


def test_model_with_multiple_exec_paths(jit_args):
    def model(a=None, b=None, z=None):
        int_term = numpyro.sample('a', dist.Normal(0., 0.2))
        x_term, y_term = 0., 0.
        if a is not None:
            x = numpyro.sample('x', dist.HalfNormal(0.5))
            x_term = a * x
        if b is not None:
            y = numpyro.sample('y', dist.HalfNormal(0.5))
            y_term = b * y
        sigma = numpyro.sample('sigma', dist.Exponential(1.))
        mu = int_term + x_term + y_term
        numpyro.sample('obs', dist.Normal(mu, sigma), obs=z)

    a = jnp.exp(np.random.randn(10))
    b = jnp.exp(np.random.randn(10))
    z = np.random.randn(10)

    # Run MCMC along each execution path of the model.
    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_warmup=20, num_samples=10, jit_model_args=jit_args)
    mcmc.run(random.PRNGKey(1), a, b=None, z=z)
    assert set(mcmc.get_samples()) == {'a', 'x', 'sigma'}
    mcmc.run(random.PRNGKey(2), a=None, b=b, z=z)
    assert set(mcmc.get_samples()) == {'a', 'y', 'sigma'}
    mcmc.run(random.PRNGKey(3), a=a, b=b, z=z)
    assert set(mcmc.get_samples()) == {'a', 'x', 'y', 'sigma'}
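In this test, jnp.exp turns standard-normal draws into strictly positive (log-normally distributed) covariates a and b. A minimal sketch of that idiom on its own (array values are illustrative):

import numpy as np
import jax.numpy as jnp

a = jnp.exp(np.random.randn(10))    # NumPy input is accepted; the result is a JAX array
print(bool((a > 0).all()))          # True: exp maps every real draw to a positive value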
github sharadmv / deepx / deepx / backend / jax.py
    def sigmoid(self, x):
        # In this backend module, `np` refers to jax.numpy.
        return 1 / (1 + np.exp(-x))
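The direct 1 / (1 + exp(-x)) form can lose precision in the backward pass for large negative inputs; jax also provides jax.nn.sigmoid, which is implemented with numerical stability in mind. A small comparison sketch, not part of deepx, assuming jax's default float32 precision:

import jax
import jax.numpy as jnp

def sigmoid(x):
    return 1 / (1 + jnp.exp(-x))

x = -100.0
print(sigmoid(x), jax.nn.sigmoid(x))       # both ~0.0 in the forward pass
print(jax.grad(sigmoid)(x))                # nan under default float32: jnp.exp(100.) overflows
print(jax.grad(jax.nn.sigmoid)(x))         # tiny but finite (~exp(-100)), not nan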
github pyro-ppl / numpyro / numpyro / infer / sa.py
        scales_ = jnp.concatenate([scales, scale[None, ...]])
        if scale.ndim == 2:  # dense_mass
            log_weights_ = dist.MultivariateNormal(locs_, scale_tril=scales_).log_prob(zs_) + pes_
        else:
            log_weights_ = dist.Normal(locs_, scales_).log_prob(zs_).sum(-1) + pes_
        log_weights_ = jnp.where(jnp.isnan(log_weights_), -jnp.inf, log_weights_)
        # get rejecting index
        j = random.categorical(rng_key_reject, log_weights_)
        zs = _numpy_delete(zs_, j)
        pes = _numpy_delete(pes_, j)
        loc = locs_[j]
        scale = scales_[j]
        adapt_state = SAAdaptState(zs, pes, loc, scale)

        # NB: weights[-1] / sum(weights) is the probability of rejecting the new sample `z`.
        accept_prob = 1 - jnp.exp(log_weights_[-1] - logsumexp(log_weights_))
        itr = sa_state.i + 1
        n = jnp.where(sa_state.i < wa_steps, itr, itr - wa_steps)
        mean_accept_prob = sa_state.mean_accept_prob + (accept_prob - sa_state.mean_accept_prob) / n

        # XXX: we make a modification of SA sampler in [1]
        # in [1], each MCMC state contains N points `zs`
        # here we do resampling to pick randomly a point from those N points
        k = random.categorical(rng_key_accept, jnp.zeros(zs.shape[0]))
        z = unravel_fn(zs[k])
        pe = pes[k]
        return SAState(itr, z, pe, accept_prob, mean_accept_prob, diverging, adapt_state, rng_key)
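The jnp.exp call near the end converts a difference of log-weights into a probability: exp(log_w[i] - logsumexp(log_w)) is the normalized weight of sample i. A standalone sketch of that pattern, with illustrative weights:

import jax.numpy as jnp
from jax.scipy.special import logsumexp

log_weights = jnp.array([-3.0, -1.0, -2.0])
probs = jnp.exp(log_weights - logsumexp(log_weights))   # softmax of the log-weights
print(probs, probs.sum())                               # probabilities summing to 1
accept_prob = 1 - probs[-1]                             # mirrors the accept_prob line above
print(accept_prob)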
github pyro-ppl / numpyro / examples / stochastic_volatility.py
import jax.numpy as jnp

import numpyro
import numpyro.distributions as dist


def model(returns):
    step_size = numpyro.sample('sigma', dist.Exponential(50.))
    s = numpyro.sample('s', dist.GaussianRandomWalk(scale=step_size, num_steps=jnp.shape(returns)[0]))
    nu = numpyro.sample('nu', dist.Exponential(.1))
    # jnp.exp maps the latent log-volatility random walk `s` to a positive scale.
    return numpyro.sample('r', dist.StudentT(df=nu, loc=0., scale=jnp.exp(s)),
                          obs=returns)
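The jnp.exp(s) here maps the unconstrained Gaussian random walk onto the positive reals, since a StudentT scale must be positive. A plain-jax sketch of that idea, not taken from the example itself:

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
increments = 0.1 * jax.random.normal(key, (100,))
s = jnp.cumsum(increments)        # latent log-volatility: unconstrained, can go negative
scale = jnp.exp(s)                # strictly positive, usable as a StudentT scale
print(bool((scale > 0).all()))    # True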
github google / jax / custom_vjps3.py
@primitive
def logsumexp(x):
  # Subtracting the max before exponentiating keeps np.exp from overflowing;
  # the max is added back outside the log.
  max_x = np.max(x)
  return max_x + np.log(np.sum(np.exp(x - max_x)))
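The max-shift above is the standard log-sum-exp stabilization; jax also ships a ready-made version as jax.scipy.special.logsumexp. A quick comparison:

import jax.numpy as jnp
from jax.scipy.special import logsumexp

x = jnp.array([1000.0, 1001.0, 1002.0])
print(jnp.log(jnp.sum(jnp.exp(x))))            # inf: jnp.exp(1000.) overflows in float32
m = x.max()
print(m + jnp.log(jnp.sum(jnp.exp(x - m))))    # ~1002.41, the shifted form stays finite
print(logsumexp(x))                            # same result from the built-in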
github pyro-ppl / numpyro / numpyro / distributions / transforms.py
    def __call__(self, x):
        # XXX consider to clamp from below for stability if necessary
        return jnp.exp(x)
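This transform is just the exponential map from the real line onto the positive reals, with jnp.log as its inverse. A plain-jnp illustration, not using the numpyro class directly:

import jax.numpy as jnp

x = jnp.array([-2.0, 0.0, 3.0])            # unconstrained values
y = jnp.exp(x)                             # mapped into (0, inf)
print(bool(jnp.allclose(jnp.log(y), x)))   # True: jnp.log inverts the transform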
github google / jax / jax / nn / functions.py
def softmax(x, axis=-1):
  r"""Softmax function.

  Computes the function which rescales elements to the range :math:`[0, 1]`
  such that the elements along :code:`axis` sum to :math:`1`.

  .. math ::
    \mathrm{softmax}(x)_i = \frac{\exp(x_i)}{\sum_j \exp(x_j)}

  Args:
    x: input array.
    axis: the axis or axes along which the softmax should be computed. The
      softmax output summed across these dimensions should sum to :math:`1`.
      Either an integer or a tuple of integers.
  """
  # Subtracting the per-axis max keeps np.exp from overflowing; it cancels in the ratio.
  unnormalized = np.exp(x - x.max(axis, keepdims=True))
  return unnormalized / unnormalized.sum(axis, keepdims=True)
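A short usage sketch of the public entry point, jax.nn.softmax, checking that the outputs along the last axis sum to one:

import jax
import jax.numpy as jnp

logits = jnp.array([[1.0, 2.0, 3.0],
                    [0.0, 0.0, 0.0]])
probs = jax.nn.softmax(logits, axis=-1)
print(probs)                   # second row is uniform: [1/3, 1/3, 1/3]
print(probs.sum(axis=-1))      # [1., 1.]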
github pyro-ppl / numpyro / numpyro / contrib / distributions / discrete.py
    def _logpmf(self, x, n, p):
        x, n, p = _promote_dtypes(x, n, p)
        combiln = gammaln(n + 1) - (gammaln(x + 1) + gammaln(n - x + 1))
        if self.is_logits:
            # TODO: move this implementation to PyTorch if it does not hit the non-continuity problem
            # In PyTorch, k * logit - n * log1p(e^logit) overflows when logit is a large
            # positive number. In that case, we can reformulate it as
            # k * logit - n * log1p(e^logit) = k * logit - n * (log1p(e^-logit) + logit)
            #                                = k * logit - n * logit - n * log1p(e^-logit)
            # More context: https://github.com/pytorch/pytorch/pull/15962/
            return combiln + x * p - (n * jnp.clip(p, 0) + xlog1py(n, jnp.exp(-jnp.abs(p))))
        else:
            return combiln + xlogy(x, p) + xlog1py(n - x, -p)
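The reformulation described in the comment is easy to check numerically: for a large positive logit the naive form overflows, while the shifted form stays finite. A small sketch with illustrative values:

import jax.numpy as jnp

logit, k, n = 200.0, 3.0, 10.0
naive  = k * logit - n * jnp.log1p(jnp.exp(logit))               # jnp.exp(200.) -> inf, so this is -inf
stable = k * logit - n * logit - n * jnp.log1p(jnp.exp(-logit))  # finite: -1400.0
print(naive, stable)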
github pyro-ppl / numpyro / numpyro / distributions / continuous.py
    def log_prob(self, value):
        # Gumbel log-density: log p(value) = -(z + exp(-z)) - log(scale), with z standardized.
        z = (value - self.loc) / self.scale
        return -(z + jnp.exp(-z)) - jnp.log(self.scale)
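As an external sanity check on the formula, the same log-density can be evaluated directly and compared with SciPy's gumbel_r distribution (not part of numpyro):

import jax.numpy as jnp
from scipy.stats import gumbel_r

loc, scale, value = 1.0, 2.0, 0.5
z = (value - loc) / scale
print(-(z + jnp.exp(-z)) - jnp.log(scale))           # the formula above, evaluated directly
print(gumbel_r.logpdf(value, loc=loc, scale=scale))  # agrees up to floating-point precision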