How to use the jax.numpy.sum function in jax

To help you get started, we've selected a few jax.numpy.sum examples based on popular ways the function is used in public projects.
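
Before the project snippets, here is a minimal standalone sketch of jax.numpy.sum itself, showing the axis and keepdims parameters (the array values are arbitrary):

import jax.numpy as jnp

x = jnp.arange(6.0).reshape(2, 3)

print(jnp.sum(x))                         # 15.0, summed over all elements
print(jnp.sum(x, axis=0))                 # shape (3,), summed down the rows
print(jnp.sum(x, axis=1, keepdims=True))  # shape (2, 1), reduced axis kept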


github pyro-ppl/numpyro: numpyro/contrib/distributions/discrete.py
def _entropy(self, n, p):
    # Convert logits to probabilities; expit is the logistic sigmoid.
    if self.is_logits:
        p = expit(p)
    # Evaluate the pmf over the full support 0..n, then sum the
    # elementwise entropy entr(v) = -v * log(v) with jnp.sum.
    k = jnp.arange(n + 1)
    vals = self._pmf(k, n, p)
    return jnp.sum(entr(vals), axis=0)
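
jax.scipy.special provides both expit and entr, so the same jnp.sum(entr(...)) pattern works outside the class; a standalone sketch on a hand-written pmf (the distribution is arbitrary):

import jax.numpy as jnp
from jax.scipy.special import entr

# pmf of the number of heads in two fair coin flips.
pmf = jnp.array([0.25, 0.5, 0.25])

# entr(v) = -v * log(v) elementwise; jnp.sum collapses it to the
# Shannon entropy in nats.
print(jnp.sum(entr(pmf)))  # ~1.0397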
github google/jax: examples/gaussian_process_regression.py
def marginal_likelihood(params, x, y):
    # Kernel matrix with diagonal jitter for numerical stability.
    train_cov = cov(params, x, x) + eye * 1e-6
    # Lower-triangular Cholesky factor; lower=True below relies on this.
    chol = np.linalg.cholesky(train_cov + eye * 1e-4)
    inv_chol = scipy.linalg.solve_triangular(chol, eye, lower=True)
    inv_train_cov = np.dot(inv_chol.T, inv_chol)
    # Log marginal likelihood: -0.5 y^T K^-1 y - 0.5 log|K| - (n/2) log(2 pi),
    # where sum(log(diag(inv_chol))) equals -0.5 log|K|.
    ml = np.sum(
        -0.5 * np.dot(y.T, np.dot(inv_train_cov, y)) +
        0.5 * np.sum(2.0 * np.log(np.dot(inv_chol * eye, np.ones(
            (numpts, 1))))) - (numpts / 2.) * np.log(2. * np.pi))
    return ml

grad_fun = jit(grad(marginal_likelihood))
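
The jit(grad(...)) composition at the end is the core jax pattern here; a minimal sketch of it on a toy jnp.sum-based loss (the loss function is illustrative, not from the example):

import jax.numpy as jnp
from jax import grad, jit

def toy_loss(w):
    # jnp.sum produces a scalar, so the function can be differentiated.
    return jnp.sum(w ** 2)

grad_fn = jit(grad(toy_loss))               # compile the gradient computation
print(grad_fn(jnp.array([1.0, 2.0, 3.0])))  # [2. 4. 6.]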
github google/jax: jax/nn/functions.py
def log_softmax(x, axis=-1):
  r"""Log-Softmax function.

  Computes the logarithm of the :code:`softmax` function, which rescales
  elements to the range :math:`[-\infty, 0)`.

  .. math ::
    \mathrm{log\_softmax}(x) = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)}
    \right)

  Args:
    x: the input array.
    axis: the axis or axes along which the :code:`log_softmax` should be
      computed. Either an integer or a tuple of integers.

  Returns:
    An array of the same shape as :code:`x`.
  """
  shifted = x - x.max(axis, keepdims=True)
  return shifted - np.log(np.sum(np.exp(shifted), axis, keepdims=True))
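
As a quick sanity check of the formula above (using the public jax.nn.log_softmax), exponentiating the result and summing with jnp.sum along the same axis should return 1:

import jax.numpy as jnp
from jax import nn

x = jnp.array([[1.0, 2.0, 3.0],
               [0.0, 0.0, 0.0]])
log_probs = nn.log_softmax(x, axis=-1)

# exp(log_softmax(x)) is a probability distribution along the chosen axis.
print(jnp.sum(jnp.exp(log_probs), axis=-1))  # [1. 1.]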
github google/jax: examples/advi.py
def diag_gaussian_logpdf(x, mean, log_std):
    # Evaluate a single point on a diagonal multivariate Gaussian.
    return np.sum(vmap(norm.logpdf)(x, mean, np.exp(log_std)))
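
The same function, made self-contained with the imports it needs (norm here is jax.scipy.stats.norm, matching the snippet's usage): vmap applies the 1-D normal logpdf per coordinate, and jnp.sum turns the per-coordinate log-densities into the joint log-density.

import jax.numpy as jnp
from jax import vmap
from jax.scipy.stats import norm

def diag_gaussian_logpdf(x, mean, log_std):
    return jnp.sum(vmap(norm.logpdf)(x, mean, jnp.exp(log_std)))

print(diag_gaussian_logpdf(jnp.zeros(3), jnp.zeros(3), jnp.zeros(3)))
# 3 * log(1 / sqrt(2 * pi)) ~= -2.7568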
github probml/pyprobml: Old/examples/jax-demo.py
def loss(weights, data):
    inputs, targets = data
    preds = predict(weights, inputs)
    # Bernoulli log-likelihood (binary cross-entropy), summed over examples.
    label_logprobs = np.log(preds) * targets + np.log(1 - preds) * (1 - targets)
    return -np.sum(label_logprobs)
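
A runnable variant, with a hypothetical logistic-regression predict standing in for the demo's model (the data below is made up):

import jax.numpy as jnp
from jax import grad
from jax.nn import sigmoid

def predict(weights, inputs):
    # Hypothetical stand-in model: logistic regression.
    return sigmoid(jnp.dot(inputs, weights))

def loss(weights, data):
    inputs, targets = data
    preds = predict(weights, inputs)
    label_logprobs = jnp.log(preds) * targets + jnp.log(1 - preds) * (1 - targets)
    return -jnp.sum(label_logprobs)

inputs = jnp.array([[1.0, 0.0], [0.0, 1.0]])
targets = jnp.array([1.0, 0.0])
data = (inputs, targets)
print(loss(jnp.zeros(2), data))        # 2 * log(2) ~= 1.386
print(grad(loss)(jnp.zeros(2), data))  # gradient w.r.t. the weights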
github google/jax: jax/experimental/ode.py
def onearg_odeint(fargs):
  # Collapse the whole integrated trajectory to a scalar with np.sum
  # so the result can be differentiated with jax.grad.
  return np.sum(odeint(func, *fargs))
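
For context, a minimal sketch of odeint itself from jax.experimental.ode, integrating the arbitrary dynamics dy/dt = -y and collapsing the trajectory with jnp.sum:

import jax.numpy as jnp
from jax.experimental.ode import odeint

def f(y, t):
    # Arbitrary dynamics: dy/dt = -y.
    return -y

t = jnp.linspace(0.0, 1.0, 5)
ys = odeint(f, jnp.array(1.0), t)  # solution evaluated at the times in t
print(jnp.sum(ys))                 # scalar summary of the trajectory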
github google/trax: trax/rl/ppo.py
def approximate_kl(log_prob_new, log_prob_old, mask):
  """Computes the approximate KL divergence between the old and new log-probs.

  Args:
    log_prob_new: (B, AT, A) new log probabilities.
    log_prob_old: (B, AT, A) old log probabilities.
    mask: (B, AT) mask over valid timesteps.

  Returns:
    Approximate KL.
  """
  diff = log_prob_old - log_prob_new
  # Mask out the irrelevant part; the mask broadcasts as (B, AT, 1).
  diff *= mask[:, :, np.newaxis]
  # Average over the non-masked part.
  return np.sum(diff) / np.sum(mask)
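
A hedged usage sketch, assuming approximate_kl as defined above is in scope (with np aliased to jax.numpy, as in the trax module); the log-probs come from log_softmax so each action distribution is normalized, and the made-up mask drops the final timestep:

import jax.numpy as np   # matches the alias the snippet uses for jax.numpy
from jax import nn, random

key_new, key_old = random.split(random.PRNGKey(0))
logits_new = random.normal(key_new, (2, 3, 4))    # (B, AT, A)
logits_old = random.normal(key_old, (2, 3, 4))

log_prob_new = nn.log_softmax(logits_new, axis=-1)
log_prob_old = nn.log_softmax(logits_old, axis=-1)
mask = np.array([[1.0, 1.0, 0.0],
                 [1.0, 1.0, 0.0]])                # (B, AT), last step masked

print(approximate_kl(log_prob_new, log_prob_old, mask))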
github pyro-ppl/numpyro: benchmarks/sparse_regression.py
    mcmc = MCMC(kernel, args.num_warmup, args.num_samples,
                num_chains=args.num_chains, progress_bar=not args.disable_progbar)
    tic = time.time()
    mcmc._compile(rng_key, data['X'], data['Y'], extra_fields=('num_steps',))
    print('MCMC (numpyro) compiling time:', time.time() - tic, '\n')
    tic = time.time()
    mcmc.warmup(rng_key, data['X'], data['Y'], extra_fields=('num_steps',))
    rng_key = mcmc._warmup_state.rng_key.copy()
    tic_run = time.time()
    mcmc.run(rng_key, data['X'], data['Y'], extra_fields=('num_steps',))
    # Copying the final state's key forces the computation to finish
    # before the clock is read.
    mcmc._last_state.rng_key.copy()
    toc = time.time()
    mcmc.print_summary()
    print('\nMCMC (numpyro) elapsed time:', toc - tic)
    sampling_time = toc - tic_run
    # Total leapfrog steps across all post-warmup samples.
    num_leapfrogs = np.sum(mcmc.get_extra_fields()['num_steps'])
    print('num leapfrogs', num_leapfrogs)
    time_per_leapfrog = sampling_time / num_leapfrogs
    print('time per leapfrog', time_per_leapfrog)
    n_effs = [effective_sample_size(device_get(v))
              for k, v in mcmc.get_samples(group_by_chain=True).items()]
    n_effs = onp.concatenate([onp.array([x]) if np.ndim(x) == 0 else x
                              for x in n_effs])
    n_eff_mean = sum(n_effs) / len(n_effs)
    print('mean n_eff', n_eff_mean)
    time_per_eff_sample = sampling_time / n_eff_mean
    print('time per effective sample', time_per_eff_sample)
    return num_leapfrogs, n_eff_mean, toc - tic, time_per_leapfrog, time_per_eff_sample