How to use the jax.numpy.mean function in jax

To help you get started, we’ve selected a few examples of jax.numpy.mean, based on popular ways it is used in public projects.

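As a quick orientation before the project examples, here is a minimal, self-contained sketch of jax.numpy.mean itself; the array values are purely illustrative.

import jax.numpy as jnp

x = jnp.arange(12.0).reshape(3, 4)
print(jnp.mean(x))                         # 5.5, mean over every element
print(jnp.mean(x, axis=0))                 # column means, shape (4,)
print(jnp.mean(x, axis=1, keepdims=True))  # row means, shape (3, 1)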

github pyro-ppl / numpyro / test / test_mcmc_interface.py
        numpyro.sample('obs', dist.Categorical(p_latent), obs=data)
        return p_latent

    @jit
    def get_samples(rng_key, data, step_size, trajectory_length, target_accept_prob):
        kernel = kernel_cls(model, step_size=step_size, trajectory_length=trajectory_length,
                            target_accept_prob=target_accept_prob)
        mcmc = MCMC(kernel, warmup_steps, num_samples, num_chains=2, chain_method=chain_method,
                    progress_bar=False)
        mcmc.run(rng_key, data)
        return mcmc.get_samples()

    true_probs = np.array([0.1, 0.6, 0.3])
    data = dist.Categorical(true_probs).sample(random.PRNGKey(1), (2000,))
    samples = get_samples(rng_key, data, step_size, trajectory_length, target_accept_prob)
    assert_allclose(np.mean(samples['p_latent'], 0), true_probs, atol=0.02)
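
In the test above, np.mean(samples['p_latent'], 0) averages the posterior draws along the sample axis before comparing them to the true probabilities. A minimal sketch of that reduction with made-up draws (the shapes here are assumptions, not numpyro output):

import jax.numpy as jnp

# hypothetical posterior draws: 4000 samples of a 3-dimensional probability vector
draws = jnp.tile(jnp.array([0.1, 0.6, 0.3]), (4000, 1))
posterior_mean = jnp.mean(draws, axis=0)   # shape (3,), analogous to np.mean(samples['p_latent'], 0)
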
github JuliusKunze / jaxnet / jaxnet / modules.py
def fastvar(x, axis, keepdims):
    """A fast but less numerically-stable variance calculation than np.var."""
    return np.mean(x ** 2, axis, keepdims=keepdims) - np.mean(x, axis, keepdims=keepdims) ** 2
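
fastvar relies on the identity Var[x] = E[x**2] - E[x]**2, computed with two calls to the mean. A small sanity check against jnp.var, assuming float32 inputs and an illustrative tolerance:

import jax
import jax.numpy as jnp

def fastvar(x, axis, keepdims):
    """Var[x] = E[x**2] - E[x]**2, computed with two calls to jnp.mean."""
    return jnp.mean(x ** 2, axis, keepdims=keepdims) - jnp.mean(x, axis, keepdims=keepdims) ** 2

x = jax.random.normal(jax.random.PRNGKey(0), (8, 16))
assert jnp.allclose(fastvar(x, axis=1, keepdims=False), jnp.var(x, axis=1), atol=1e-4)
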
github tensorflow / cleverhans / tutorials / future / jax / mnist_classifier.py
def loss(params, batch):
    inputs, targets = batch
    preds = predict(params, inputs)
    return -np.mean(logsoftmax(preds) * targets)
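
A self-contained sketch of the same mean cross-entropy pattern, with jax.nn.log_softmax standing in for the tutorial's logsoftmax and a dummy linear model standing in for predict (both substitutions are assumptions made for illustration):

import jax
import jax.numpy as jnp

def loss(params, batch):
    inputs, targets = batch                  # targets are one-hot, shape (batch, num_classes)
    logits = inputs @ params                 # stand-in for the tutorial's predict()
    return -jnp.mean(jax.nn.log_softmax(logits) * targets)

params = jnp.zeros((4, 3))
batch = (jnp.ones((2, 4)), jax.nn.one_hot(jnp.array([0, 2]), 3))
print(loss(params, batch))                   # the mean is taken over every element, as in the snippet
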
github pyro-ppl / numpyro / examples / cg.py
    D = b.shape[-1]

    presolve = lowrank_presolve(kX, diag, eta1, eta2, c, kappa, rank1, rank2)
    mvm = lambda b: np.matmul(A, b)

    b_probes = np.concatenate([b[None, :], probes])
    Ainv_b_probes, res_norm, iters = pcg_batch_b(b_probes, mvm, presolve=presolve, cg_tol=cg_tol, max_iters=max_iters)
    Ainv_b, Ainv_probes = Ainv_b_probes[0], Ainv_b_probes[1:]

    quad_form_dA = -np.dot(Ainv_b, np.matmul(A_dot, Ainv_b))
    quad_form_db = 2.0 * np.dot(Ainv_b, b_dot)
    log_det_dA = np.mean(np.einsum('...i,...i->...', np.matmul(probes, A_dot), Ainv_probes))
    tangent_out = log_det_dA + quad_form_dA + quad_form_db
    quad_form = np.dot(b, Ainv_b)

    return (quad_form, np.mean(res_norm), np.mean(iters)), (tangent_out, 0.0, 0.0)
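
The log_det_dA line is a probe-based (Hutchinson-style) trace estimate: np.mean averages the quadratic form z·(A_dot A⁻¹ z) over random probe vectors z. A minimal sketch of that pattern on a small dense system; the matrices, probe count, and variable names are illustrative assumptions, not the example's API:

import jax
import jax.numpy as jnp

A = 2.0 * jnp.eye(5)                               # SPD matrix standing in for the kernel matrix
A_dot = jnp.eye(5)                                 # its tangent
probes = jax.random.normal(jax.random.PRNGKey(0), (64, 5))    # rows are probe vectors z

Ainv_probes = jnp.linalg.solve(A, probes.T).T      # A⁻¹ z for every probe
trace_estimate = jnp.mean(jnp.einsum('...i,...i->...', probes @ A_dot, Ainv_probes))
exact = jnp.trace(A_dot @ jnp.linalg.inv(A))       # 2.5; the estimate approaches this with enough probes
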
github google / trax / trax / rl / ppo.py
  assert (RT + 1, AT) == rewards_to_actions.shape

  # (B, RT)
  td_deltas = deltas(
      value_predictions_old,  # (B, RT+1)
      padded_rewards,
      reward_mask,
      gamma=gamma)

  # (B, RT)
  advantages = gae_advantages(
      td_deltas, reward_mask, lambda_=lambda_, gamma=gamma)

  # Normalize the advantages.
  advantage_mean = np.mean(advantages)
  advantage_std = np.std(advantages)
  advantages = (advantages - advantage_mean) / (advantage_std + 1e-8)

  # Scatter advantages over padded_actions.
  # rewards_to_actions is RT + 1 -> AT, so we pad the advantages and the reward
  # mask by 1.
  advantages = np.dot(np.pad(advantages, ((0, 0), (0, 1))), rewards_to_actions)
  action_mask = np.dot(
      np.pad(reward_mask, ((0, 0), (0, 1))), rewards_to_actions
  )

  # (B, AT)
  ratios = compute_probab_ratios(log_probab_actions_new, log_probab_actions_old,
                                 padded_actions, action_mask)
  assert (B, AT) == ratios.shape
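
The normalization step above is the usual (x - mean) / (std + eps) pattern; a stand-alone sketch with dummy advantages (only the 1e-8 epsilon is taken from the snippet):

import jax.numpy as jnp

advantages = jnp.array([[1.0, 3.0, 5.0], [2.0, 4.0, 6.0]])            # dummy (B, RT) values
advantage_mean = jnp.mean(advantages)
advantage_std = jnp.std(advantages)
advantages = (advantages - advantage_mean) / (advantage_std + 1e-8)   # ~zero mean, ~unit std
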
github bethgelab / foolbox / foolbox / models / jax.py
        def cross_entropy(logits, labels):
            assert logits.ndim == 2
            assert labels.ndim == 1
            assert len(logits) == len(labels)
            logprobs = logits - logsumexp(logits, axis=1, keepdims=True)
            nll = jnp.take_along_axis(logprobs, jnp.expand_dims(labels, axis=1), axis=1)
            ce = -jnp.mean(nll)
            return ce
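
Unlike the one-hot variant earlier, this loss indexes the log-probabilities with integer labels before averaging. A hedged sketch of the same jnp.mean-of-NLL reduction with made-up logits and labels:

import jax.numpy as jnp
from jax.scipy.special import logsumexp

logits = jnp.array([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])
labels = jnp.array([0, 2])                                     # integer class labels
logprobs = logits - logsumexp(logits, axis=1, keepdims=True)   # log-softmax by hand
nll = jnp.take_along_axis(logprobs, jnp.expand_dims(labels, axis=1), axis=1)
ce = -jnp.mean(nll)                                            # mean negative log-likelihood
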
github JuliusKunze / jaxnet / examples / wavenet.py
    def loss(batch):
        theta = wavenet(batch)[:, :-1, :]
        # now slice the padding off the batch
        sliced_batch = batch[:, receptive_field:, :]
        return (np.mean(discretized_mix_logistic_loss(
            theta, sliced_batch, num_class=1 << 16), axis=0)
                * np.log2(np.e) / (output_width - 1))
github pyro-ppl / numpyro / examples / vjp.py
    kX = kappa * X
    omega_b = b * diag

    mvm = lambda _b: kernel_mvm_diag(_b, kX, eta1, eta2, c, diag, dilation=dilation, dilation2=dilation2)
    presolve = lowrank_presolve(kX, diag, eta1, eta2, c, kappa, rank1, rank2)

    om_b_probes = np.concatenate([omega_b[None, :], probes])
    Ainv_om_b_probes, res_norm, iters = pcg_batch_b(om_b_probes, mvm, presolve=presolve,
                                                    cg_tol=cg_tol, max_iters=max_iters)
    Ainv_om_b, Ainv_probes = Ainv_om_b_probes[0], Ainv_om_b_probes[1:]
    K_Ainv_om_b = kernel_mvm(Ainv_om_b, kX, eta1, eta2, c, dilation=dilation, dilation2=dilation2)
    quad_form = 0.125 * np.dot(b, K_Ainv_om_b)

    residuals = (kX, kappa, eta1, eta2, K_Ainv_om_b, Ainv_om_b, diag, Ainv_probes, probes)

    return (quad_form, np.mean(res_norm), np.mean(iters)), residuals