How to use the jax.numpy.log function in jax

To help you get started, we’ve selected a few jax.numpy.log examples, based on popular ways it is used in public projects.

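To ground the snippets that follow: jax.numpy.log is the elementwise natural logarithm, mirroring numpy.log. A minimal sketch (array values are illustrative):

import jax.numpy as jnp

x = jnp.array([1.0, jnp.e, 10.0])
print(jnp.log(x))                # [0.  1.  2.3025851] -- elementwise natural log
print(jnp.log(jnp.array(0.0)))   # -inf; JAX returns -inf/nan instead of raising
print(jnp.log(jnp.array(-1.0)))  # nan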

github pyro-ppl / numpyro / numpyro / infer / hmc_util.py
def find_reasonable_step_size(potential_fn, kinetic_fn, momentum_generator,
                              init_step_size, inverse_mass_matrix, z_info, rng_key):
    """
    Finds a reasonable step size by tuning ``init_step_size``.

    **References:**

    1. *The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian
       Monte Carlo*, Matthew D. Hoffman, Andrew Gelman

    :param potential_fn: A callable to compute potential energy.
    :param kinetic_fn: A callable to compute kinetic energy.
    :param momentum_generator: A generator to get a random momentum variable.
    :param float init_step_size: Initial step size to be tuned.
    :param inverse_mass_matrix: Inverse of mass matrix.
    :param IntegratorState z_info: The current integrator state.
    :param jax.random.PRNGKey rng_key: Random key to be used as the source of randomness.
    :return: a reasonable value for step size.
    :rtype: float
    """
    # We are going to find a step_size which makes accept_prob (the Metropolis
    # correction) close to target_accept_prob. If accept_prob := exp(-delta_energy)
    # is too small, we decrease step_size; otherwise, we increase it.
    target_accept_prob = jnp.log(0.8)

    _, vv_update = velocity_verlet(potential_fn, kinetic_fn)
    z, _, potential_energy, z_grad = z_info
    if potential_energy is None or z_grad is None:
        potential_energy, z_grad = value_and_grad(potential_fn)(z)
    finfo = jnp.finfo(get_dtype(init_step_size))

    def _body_fn(state):
        step_size, _, direction, rng_key = state
        rng_key, rng_key_momentum = random.split(rng_key)
        # scale step_size: double or halve it depending on direction;
        # direction=1 means keep increasing step_size, otherwise keep decreasing it.
        # Note that direction is -1 if delta_energy is `NaN`, which may be the
        # case for a diverging trajectory (e.g. when evaluating the log prob
        # of a value simulated with a large step size for a constrained sample site).
        step_size = (2.0 ** direction) * step_size
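
The loop body above feeds a direction update that compares log probabilities directly: storing target_accept_prob as jnp.log(0.8) lets the Metropolis check avoid exponentiating the energy difference. A minimal sketch of that pattern, where next_direction is a hypothetical helper (not numpyro API) and the delta_energy values are made up:

import jax.numpy as jnp

target_accept_prob = jnp.log(0.8)  # compare acceptance in log space

def next_direction(delta_energy):
    # accept_prob = exp(-delta_energy); comparing logs avoids the exp, and a
    # NaN delta_energy (diverging trajectory) falls through to direction=-1
    return jnp.where(target_accept_prob < -delta_energy, 1, -1)

print(next_direction(jnp.array(0.1)))      # 1: exp(-0.1) ~ 0.90 > 0.8, keep doubling
print(next_direction(jnp.array(1.0)))      # -1: exp(-1.0) ~ 0.37 < 0.8, halve instead
print(next_direction(jnp.array(jnp.nan)))  # -1: NaN comparisons are False
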
github pyro-ppl / numpyro / numpyro / distributions / continuous.py
def _batch_lowrank_logdet(W, D, capacitance_tril):
    r"""
    Uses "matrix determinant lemma"::
        log|W @ W.T + D| = log|C| + log|D|,
    where :math:`C` is the capacitance matrix :math:`I + W.T @ inv(D) @ W`, to compute
    the log determinant.
    """
    return 2 * jnp.sum(jnp.log(jnp.diagonal(capacitance_tril, axis1=-2, axis2=-1)), axis=-1) + jnp.log(D).sum(-1)
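
A quick self-contained check of the lemma, with illustrative shapes (a rank-2 factor W over 5 dimensions) and jnp.linalg.slogdet as the reference:

import jax
import jax.numpy as jnp

key_w, key_d = jax.random.split(jax.random.PRNGKey(0))
W = jax.random.normal(key_w, (5, 2))         # low-rank factor
D = jnp.exp(jax.random.normal(key_d, (5,)))  # positive diagonal

# capacitance matrix C = I + W.T @ inv(D) @ W, and its Cholesky factor
capacitance_tril = jnp.linalg.cholesky(jnp.eye(2) + (W.T / D) @ W)

# log|W @ W.T + D| via the lemma: log|C| + log|D|
logdet_lemma = (2 * jnp.sum(jnp.log(jnp.diagonal(capacitance_tril)))
                + jnp.log(D).sum())
logdet_direct = jnp.linalg.slogdet(W @ W.T + jnp.diag(D))[1]
print(jnp.allclose(logdet_lemma, logdet_direct, atol=1e-5))  # True
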
github pyro-ppl / numpyro / examples / ode.py
"""
    # initial population
    z_init = numpyro.sample("z_init", dist.LogNormal(jnp.log(10), 1), sample_shape=(2,))
    # measurement times
    ts = jnp.arange(float(N))
    # parameters alpha, beta, gamma, delta of dz_dt
    theta = numpyro.sample(
        "theta",
        dist.TruncatedNormal(low=0., loc=jnp.array([0.5, 0.05, 1.5, 0.05]),
                             scale=jnp.array([0.5, 0.05, 0.5, 0.05])))
    # integrate dz/dt, the result will have shape N x 2
    z = odeint(dz_dt, z_init, ts, theta, rtol=1e-5, atol=1e-3, mxstep=500)
    # measurement errors, we expect that measured hare has larger error than measured lynx
    sigma = numpyro.sample("sigma", dist.Exponential(jnp.array([1, 2])))
    # measured populations (in log scale)
    numpyro.sample("y", dist.Normal(jnp.log(z), sigma), obs=y)
github pyro-ppl / numpyro / numpyro / distributions / normal.py
# Source code modified from scipy.stats._continuous_distns.py
#
# Copyright (c) 2001, 2002 Enthought, Inc.
# All rights reserved.
#
# Copyright (c) 2003-2019 SciPy Developers.
# All rights reserved.

import jax.numpy as np
import jax.random as random

from numpyro.distributions.distribution import jax_continuous


_norm_pdf_C = np.sqrt(2 * np.pi)
_norm_pdf_logC = np.log(_norm_pdf_C)


def _norm_pdf(x):
    return np.exp(-x ** 2 / 2.0) / _norm_pdf_C


def _norm_logpdf(x):
    return -x ** 2 / 2.0 - _norm_pdf_logC


class norm_gen(jax_continuous):
    def _rvs(self):
        return random.normal(self._random_state, self._size)

    def _stats(self):
        return 0.0, 1.0, 0.0, 0.0
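
The only np.log call here precomputes the constant log sqrt(2*pi), so _norm_logpdf never exponentiates. A quick check against jax.scipy.stats.norm.logpdf, reusing the snippet's np alias:

import jax.numpy as np
from jax.scipy.stats import norm

_norm_pdf_logC = np.log(np.sqrt(2 * np.pi))

def _norm_logpdf(x):
    return -x ** 2 / 2.0 - _norm_pdf_logC

x = np.linspace(-3.0, 3.0, 7)
print(np.allclose(_norm_logpdf(x), norm.logpdf(x)))  # True
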
github pyro-ppl / numpyro / numpyro / contrib / distributions / continuous.py
def _logpdf(self, x):
        return 0.5 * jnp.log(2.0 / jnp.pi) - x * x / 2.0
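
This is the standard half-normal log density for x >= 0: the 0.5 * jnp.log(2.0 / jnp.pi) term is log 2 plus the normal constant. A sketch verifying it equals log(2) plus the standard normal logpdf:

import jax.numpy as jnp
from jax.scipy.stats import norm

def halfnorm_logpdf(x):
    # valid for x >= 0; equals log(2) + standard normal logpdf
    return 0.5 * jnp.log(2.0 / jnp.pi) - x * x / 2.0

x = jnp.linspace(0.0, 3.0, 7)
print(jnp.allclose(halfnorm_logpdf(x), jnp.log(2.0) + norm.logpdf(x)))  # True
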
github pyro-ppl / numpyro / examples / baseball.py
def predict(model, at_bats, hits, z, rng_key, player_names, train=True):
    header = model.__name__ + (' - TRAIN' if train else ' - TEST')
    predictions = Predictive(model, posterior_samples=z)(rng_key, at_bats)['obs']
    print_results('=' * 30 + header + '=' * 30,
                  predictions,
                  player_names,
                  at_bats,
                  hits)
    if not train:
        post_loglik = log_likelihood(model, z, at_bats, hits)['obs']
        # computes expected log predictive density at each data point
        exp_log_density = logsumexp(post_loglik, axis=0) - jnp.log(jnp.shape(post_loglik)[0])
        # reports log predictive density of all test points
        print('\nLog pointwise predictive density: {:.2f}\n'.format(exp_log_density.sum()))
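
The jnp.log term turns logsumexp into a stable log-mean-exp: logsumexp(ll, 0) - log(S) equals the log of the Monte Carlo average of exp(ll) over S posterior samples. A standalone sketch with a made-up log-likelihood matrix:

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

# hypothetical (num_samples, num_points) pointwise log-likelihoods
post_loglik = jax.random.normal(jax.random.PRNGKey(0), (1000, 5)) - 3.0

# log E_s[p(y_i | theta_s)] per point, computed without underflow
exp_log_density = logsumexp(post_loglik, axis=0) - jnp.log(post_loglik.shape[0])
print(jnp.allclose(exp_log_density,
                   jnp.log(jnp.mean(jnp.exp(post_loglik), axis=0))))  # True
print(exp_log_density.sum())  # total log pointwise predictive density
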
github JuliusKunze / jaxnet / examples / wavenet.py
    # log probability for edge case of 0 (before scaling):
    log_cdf_plus = plus_in - softplus(plus_in)
    # log probability for edge case of 255 (before scaling):
    log_one_minus_cdf_min = - softplus(min_in)

    cdf_delta = cdf_plus - cdf_min  # probability for all other cases
    mid_in = inv_stdv * centered_y

    log_pdf_mid = mid_in - log_scales - 2. * softplus(mid_in)

    log_probs = np.where(
        y < -0.999, log_cdf_plus,
        np.where(y > 0.999, log_one_minus_cdf_min,
                 np.where(cdf_delta > 1e-5,
                          np.log(np.maximum(cdf_delta, 1e-12)),
                          log_pdf_mid - np.log((num_class - 1) / 2))))

    log_probs = log_probs + log_softmax(logit_probs)
    return -np.sum(logsumexp(log_probs, axis=-1), axis=-1)
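
np.log(np.maximum(cdf_delta, 1e-12)) is the usual guard against log(0): both branches of np.where are evaluated eagerly in JAX, so an unclamped log would inject -inf (and NaN gradients) even when the other branch is selected. A minimal sketch of the clamp's effect:

import jax.numpy as jnp

cdf_delta = jnp.array([0.3, 1e-20, 0.0])
print(jnp.log(cdf_delta))                      # [-1.20, -46.05, -inf]
print(jnp.log(jnp.maximum(cdf_delta, 1e-12)))  # floored at log(1e-12) ~ -27.63
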
github pyro-ppl / numpyro / numpyro / contrib / distributions / multivariate.py
def logpmf(self, x, p):
        batch_shape = lax.broadcast_shapes(x.shape, p.shape[:-1])
        # append a dimension to x
        # TODO: consider to convert x.dtype to int
        x = jnp.expand_dims(x, axis=-1)
        x = jnp.broadcast_to(x, batch_shape + (1,))
        p = jnp.broadcast_to(p, batch_shape + p.shape[-1:])
        if self.is_logits:
            # normalize log prob
            p = p - logsumexp(p, axis=-1, keepdims=True)
            # gather and remove the trailing dimension
            return jnp.take_along_axis(p, x, axis=-1)[..., 0]
        else:
            return jnp.take_along_axis(jnp.log(p), x, axis=-1)[..., 0]
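
The non-logits branch is just jnp.log of the probabilities followed by an indexed gather. A small sketch with two illustrative categorical distributions:

import jax.numpy as jnp

p = jnp.array([[0.2, 0.5, 0.3],
               [0.1, 0.1, 0.8]])  # two categorical distributions
x = jnp.array([1, 2])             # observed category per row

# gather log p[i, x[i]]: add a trailing dim for the lookup, then drop it
logpmf = jnp.take_along_axis(jnp.log(p), x[..., None], axis=-1)[..., 0]
print(jnp.exp(logpmf))  # [0.5 0.8]
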
github pyro-ppl / numpyro / numpyro / distributions / util.py
def logmatmulexp(x, y):
    """
    Numerically stable version of ``(x.exp() @ y.exp()).log()``.
    """
    x_shift = lax.stop_gradient(jnp.amax(x, -1, keepdims=True))
    y_shift = lax.stop_gradient(jnp.amax(y, -2, keepdims=True))
    xy = jnp.log(jnp.matmul(jnp.exp(x - x_shift), jnp.exp(y - y_shift)))
    return xy + x_shift + y_shift
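
A usage sketch: the shifts cancel, so the result matches the naive computation whenever the naive one doesn't overflow (the function is repeated here so the check runs standalone):

import jax
import jax.numpy as jnp
from jax import lax

def logmatmulexp(x, y):
    x_shift = lax.stop_gradient(jnp.amax(x, -1, keepdims=True))
    y_shift = lax.stop_gradient(jnp.amax(y, -2, keepdims=True))
    xy = jnp.log(jnp.matmul(jnp.exp(x - x_shift), jnp.exp(y - y_shift)))
    return xy + x_shift + y_shift

key1, key2 = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key1, (3, 4))
y = jax.random.normal(key2, (4, 2))
print(jnp.allclose(logmatmulexp(x, y),
                   jnp.log(jnp.exp(x) @ jnp.exp(y)), atol=1e-5))  # True
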
github pyro-ppl / numpyro / numpyro / distributions / continuous.py
def log_prob(self, value):
        M = _batch_mahalanobis(self.scale_tril, value - self.loc)
        half_log_det = jnp.log(jnp.diagonal(self.scale_tril, axis1=-2, axis2=-1)).sum(-1)
        normalize_term = half_log_det + 0.5 * self.scale_tril.shape[-1] * jnp.log(2 * jnp.pi)
        return - 0.5 * M - normalize_term
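
The jnp.log over the Cholesky diagonal gives half the log-determinant of the covariance, since |Sigma| = |L|^2 and the determinant of the triangular L is the product of its diagonal. A small check with an illustrative 2x2 covariance:

import jax.numpy as jnp

cov = jnp.array([[2.0, 0.5],
                 [0.5, 1.0]])
scale_tril = jnp.linalg.cholesky(cov)

# sum of log-diagonal of the Cholesky factor is half the log-determinant
half_log_det = jnp.log(jnp.diagonal(scale_tril)).sum()
print(jnp.allclose(half_log_det, 0.5 * jnp.linalg.slogdet(cov)[1]))  # True
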