How to use the `jax.random` module in JAX

To help you get started, we’ve selected a few `jax.random` examples, based on popular ways the module is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

GitHub: pyro-ppl/numpyro — test/test_mcmc.py (view on GitHub)
def test_chain(use_init_params, chain_method):
    """Run multi-chain NUTS on a Bayesian logistic regression and check the draws.

    Synthetic data: ``N`` rows of ``dim`` standard-normal features, with labels
    drawn from a Bernoulli whose logits use the coefficients ``1, ..., dim``.
    The model puts a standard-normal prior on ``coefs``. ``num_chains`` chains
    are run with ``chain_method``; when ``use_init_params`` is true every chain
    starts at ``coefs = 1``.

    Checks that the flattened samples have one row per (chain, draw) and that
    the posterior mean recovers the generating coefficients.
    """
    N, dim = 3000, 3
    num_chains = 2
    num_warmup, num_samples = 5000, 5000
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1., dim + 1.)
    logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(labels):
        coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        logits = jnp.sum(coefs * data, axis=-1)
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    kernel = NUTS(model=model)
    mcmc = MCMC(kernel, num_warmup, num_samples, num_chains=num_chains)
    mcmc.chain_method = chain_method
    # One starting point per chain: a (num_chains, dim) array of ones.
    init_params = {'coefs': jnp.ones((num_chains, dim))} if use_init_params else None
    mcmc.run(random.PRNGKey(2), labels, init_params=init_params)

    samples_flat = mcmc.get_samples()
    # numpyro's get_samples() (group_by_chain=False by default) stacks all
    # chains along the leading axis.
    assert samples_flat['coefs'].shape == (num_chains * num_samples, dim)
    # The posterior mean should sit near the generating coefficients; the
    # tolerance is deliberately loose because this is a stochastic check.
    assert jnp.allclose(jnp.mean(samples_flat['coefs'], 0), true_coefs, atol=0.21)
GitHub: google/jax — examples/mnist_vae.py (view on GitHub)
def image_sample(rng, params, nrow, ncol):
  """Sample images from the generative model.

  Draws ``nrow * ncol`` latent codes of size 10 from a standard normal, decodes
  them with the decoder half of ``params``, treats the decoder output as
  Bernoulli logits, samples binary pixels and tiles the result into an
  ``nrow`` x ``ncol`` grid of 28x28 images via ``image_grid``.

  Args:
    rng: PRNG key; split into one key for the codes and one for the pixels.
    params: ``(encoder_params, decoder_params)`` pair; only the decoder is used.
    nrow, ncol: grid dimensions.
  """
  _, dec_params = params
  code_rng, img_rng = random.split(rng)
  logits = decode(dec_params, random.normal(code_rng, (nrow * ncol, 10)))
  # ``random.bernoulli`` expects probabilities in [0, 1], so map logits through
  # the sigmoid. ``exp(-logaddexp(0, -z))`` is sigmoid(z) in a numerically stable
  # form. Passing ``logaddexp(0, logits)`` (softplus) here instead would give
  # values above 1 for positive logits and turn most pixels on.
  probs = np.exp(-np.logaddexp(0., -logits))
  sampled_images = random.bernoulli(img_rng, probs)
  return image_grid(nrow, ncol, sampled_images, (28, 28))
GitHub: pyro-ppl/numpyro — numpyro/infer/elbo.py (view on GitHub)
seeded_model = seed(model, model_seed)
            seeded_guide = seed(guide, guide_seed)
            guide_log_density, guide_trace = log_density(seeded_guide, args, kwargs, param_map)
            seeded_model = replay(seeded_model, guide_trace)
            model_log_density, _ = log_density(seeded_model, args, kwargs, param_map)

            # log p(z) - log q(z)
            elbo = model_log_density - guide_log_density
            return elbo

        # Return (-elbo) since by convention we do gradient descent on a loss and
        # the ELBO is a lower bound that needs to be maximized.
        if self.num_particles == 1:
            return - single_particle_elbo(rng_key)
        else:
            rng_keys = random.split(rng_key, self.num_particles)
            return - np.mean(vmap(single_particle_elbo)(rng_keys))
GitHub: tensorflow/tensor2tensor — tensor2tensor/trax/backend.py (view on GitHub)
"name": "jax",
    "np": jnp,
    "logsumexp": jax_special.logsumexp,
    "expit": jax_special.expit,
    "erf": jax_special.erf,
    "conv": jax_conv,
    "avg_pool": jax_avg_pool,
    "max_pool": jax_max_pool,
    "sum_pool": jax_sum_pool,
    "jit": jax.jit,
    "grad": jax.grad,
    "pmap": jax.pmap,
    "eval_on_shapes": jax_eval_on_shapes,
    "random_uniform": jax_random.uniform,
    "random_randint": jax_randint,
    "random_normal": jax_random.normal,
    "random_bernoulli": jax_random.bernoulli,
    "random_get_prng": jax.jit(jax_random.PRNGKey),
    "random_split": jax_random.split,
    "dataset_as_numpy": tfds.as_numpy,
}


# Pure-NumPy backend table: JIT is a no-op and there is no PRNG state, so every
# "key" is represented by ``None``.
_NUMPY_BACKEND = dict(
    name="numpy",
    np=onp,
    jit=lambda f: f,  # no compilation: hand the function back unchanged
    random_get_prng=lambda seed: None,
    random_split=lambda prng, num=2: tuple([None] * num),
    expit=lambda x: 1. / (1. + onp.exp(-x)),  # logistic sigmoid
)
GitHub: probml/pyprobml — scripts/mnist_vae_jax.py (view on GitHub)
def evaluate(opt_state, images):
    """Score a batch of test images and draw a grid of model samples.

    Three keys are derived from the module-level ``test_rng``: one to binarize
    ``images`` (their values are used as Bernoulli probabilities), one for the
    ``elbo`` call, and one for ``image_sample``. The ELBO is divided by the
    batch size (presumably ``elbo`` sums over the batch; confirm against its
    definition).

    Returns:
        ``(normalized test ELBO, sampled image grid)``.
    """
    params = get_params(opt_state)
    elbo_key, binarize_key, sample_key = random.split(test_rng, 3)
    binarized_images = random.bernoulli(binarize_key, images)
    batch_size = images.shape[0]
    mean_elbo = elbo(elbo_key, params, binarized_images) / batch_size
    grid = image_sample(sample_key, params, nrow, ncol)
    return mean_elbo, grid
GitHub: pyro-ppl/numpyro — numpyro/infer/hmc_util.py (view on GitHub)
def _double_tree(current_tree, vv_update, kinetic_fn, inverse_mass_matrix, step_size,
                 going_right, rng_key, energy_current, max_delta_energy, r_ckpts, r_sum_ckpts):
    """Build a new subtree in the chosen direction and merge it into ``current_tree``.

    ``rng_key`` is split in two. The first key is handed to
    ``_iterative_build_subtree``, which receives every other argument unchanged.
    The second key is used by ``_combine_tree`` to join ``current_tree`` with
    the new subtree (``True`` is passed as its final positional argument).
    Judging by the name, this is the NUTS trajectory-doubling step; confirm
    against the helpers' definitions.
    """
    subtree_key, merge_key = random.split(rng_key)
    subtree = _iterative_build_subtree(
        current_tree, vv_update, kinetic_fn, inverse_mass_matrix, step_size,
        going_right, subtree_key, energy_current, max_delta_energy,
        r_ckpts, r_sum_ckpts,
    )
    merged = _combine_tree(
        current_tree, subtree, inverse_mass_matrix, going_right, merge_key, True,
    )
    return merged
GitHub: google/trax — trax/layers/research/efficient_attention.py (view on GitHub)
assert len(vecs.shape) == 2
    n_vecs = vecs.shape[0]

    rng1, rng2 = backend.random.split(rng, num=2)

    # We need to sample 2 * n_hashes * r_div_2 vectors from `vecs` at random.
    num_needed = 2 * n_hashes * r_div_2
    if n_vecs < num_needed:
      # shape = (n_hashes, r_div_2)
      random_idxs_1 = jax.random.randint(
          rng1, (n_hashes, r_div_2), 0, n_vecs)
      random_idxs_2 = jax.random.randint(
          rng2, (n_hashes, r_div_2), 0, n_vecs)
    else:
      # Sample without replacement.
      shuffled_indices = jax.random.shuffle(rng1, np.arange(n_vecs))
      random_idxs = np.reshape(shuffled_indices[:num_needed],
                               (2, n_hashes, r_div_2))
      random_idxs_1 = random_idxs[0]
      random_idxs_2 = random_idxs[1]

    if self._data_rotation_farthest:
      # shape = (n_hashes * r_div_2, )
      random_idxs_1 = np.reshape(random_idxs_1, (-1,))
      random_vecs_1 = vecs[random_idxs_1]

      # Sample candidates for vec2s.
      rng, subrng = backend.random.split(rng)
      # shape = (self._data_rotation_farthest_num, n_hashes * r_div_2)
      candidate_idxs_2 = jax.random.randint(
          subrng, (self._data_rotation_farthest_num, n_hashes * r_div_2), 0,
          n_vecs)
GitHub: sharadmv/deepx — deepx/backend/jax.py (view on GitHub)
def __init__(self, seed):
        """Create the initial PRNG key from ``seed``.

        Args:
            seed: integer seed passed to ``random.PRNGKey``. ``None`` falls back
                to 0, the value the key was previously hard-coded to.

        Sets both ``self.key`` and ``self.subkey`` to that initial key.
        """
        # Honour the caller's seed; it used to be ignored in favour of a literal 0.
        key = random.PRNGKey(0 if seed is None else seed)
        self.key = key
        self.subkey = key
GitHub: pyro-ppl/numpyro — numpyro/contrib/distributions/multivariate.py (view on GitHub)
def _rvs(self, alpha):
        """Draw Dirichlet(``alpha``) variates by normalising independent Gammas.

        Gamma(``alpha_k``) variates are drawn with the key ``self._random_state``
        for a batch of shape ``self._size``. Dividing each draw by its sum over
        the last axis puts every sample on the probability simplex.

        Returns:
            Array of shape ``self._size + (alpha.shape[-1],)``.
        """
        num_categories = alpha.shape[-1]
        draw_shape = self._size + (num_categories,)
        unnormalised = random.gamma(self._random_state, alpha, shape=draw_shape)
        normaliser = jnp.sum(unnormalised, axis=-1, keepdims=True)
        return unnormalised / normaliser
GitHub: sharadmv/deepx — deepx/backend/jax.py (view on GitHub)
def dropout(self, x, p, seed=None):
        """Apply inverted dropout to ``x``.

        Each element is kept with probability ``1 - p`` and rescaled by
        ``1 / (1 - p)``; dropped elements become 0.

        Args:
            x: input array.
            p: probability of dropping an element.
            seed: optional source of randomness. An integer is converted to a
                key with ``random.PRNGKey``; any other non-``None`` value is
                used directly as a PRNG key. When ``None`` (the default), the
                next key is taken from ``self.rng``.

        Returns:
            Array with the same shape as ``x``.
        """
        import numbers

        # An explicit ``seed`` used to be silently overwritten by next(self.rng).
        if seed is None:
            key = next(self.rng)
        elif isinstance(seed, numbers.Integral):
            key = random.PRNGKey(seed)
        else:
            key = seed
        keep_prob = 1 - p
        keep = random.bernoulli(key, keep_prob, x.shape)
        return np.where(keep, x / keep_prob, 0)