How to use the `jax.random.split` function in JAX

To help you get started, we’ve selected a few JAX examples based on popular ways `jax.random.split` is used in public projects.

Secure your code as it's written: use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

GitHub: pyro-ppl / funsor / test (View on GitHub)
# Excerpt of a funsor test. It samples every subset of a random Gaussian's event
# variables and checks the inputs of each result. On non-torch backends a fresh
# subkey is split off for every draw, so no PRNG key is reused.
# NOTE(review): the enclosing `def` line and the line defining `be_inputs` are
# not shown in this excerpt (presumably batch + event inputs — confirm upstream).
# `expected_inputs` is built from the raw input lists before they are converted
# to OrderedDicts below.
expected_inputs = OrderedDict(sample_inputs + batch_inputs + event_inputs)
    sample_inputs = OrderedDict(sample_inputs)
    batch_inputs = OrderedDict(batch_inputs)
    event_inputs = OrderedDict(event_inputs)
    x = random_gaussian(be_inputs)
    # torch backend: no key (None). Otherwise start from the raw uint32 key [0, 0].
    rng_key = subkey = None if get_backend() == "torch" else np.array([0, 0], dtype=np.uint32)

    xfail = False
    # Try every subset of event variables, from none up to all of them.
    for num_sampled in range(len(event_inputs) + 1):
        for sampled_vars in itertools.combinations(list(event_inputs), num_sampled):
            sampled_vars = frozenset(sampled_vars)
            print('sampled_vars: {}'.format(', '.join(sampled_vars)))
                # NOTE(review): the excerpt omits the `try:` that pairs with the
                # `except NotImplementedError:` below.
                if rng_key is not None:
                    import jax
                    # Keep one half as the running key; hand the other to this draw only.
                    rng_key, subkey = jax.random.split(rng_key)

                y = x.sample(sampled_vars, sample_inputs, rng_key=subkey)
            except NotImplementedError:
                xfail = True
                # NOTE(review): as shown, `y` below may be unbound or stale after this
                # branch; upstream presumably `continue`s here — confirm.
            if num_sampled == len(event_inputs):
                # Sampling every event variable should yield a Delta or a Contraction.
                assert isinstance(y, (Delta, Contraction))
            if sampled_vars:
                assert dict(y.inputs) == dict(expected_inputs), sampled_vars
                # NOTE(review): an `else:` appears elided here. As shown, `y is x`
                # would also be asserted when variables were sampled.
                assert y is x
    # Report the test as expected-failure if any combination was unimplemented.
    if xfail:
        pytest.xfail(reason='Not implemented')
GitHub: pyro-ppl / numpyro / numpyro / infer (View on GitHub)
def init_fn(z_info, rng_key, step_size=1.0, inverse_mass_matrix=None, mass_matrix_size=None):
        # Build the initial state of the warmup adaptation scheme. This is an
        # excerpt: the function continues past the last line shown here.
        # NOTE(review): the docstring's triple quotes were lost in this excerpt,
        # so the parameter lines below appear as bare text rather than a string.
        :param IntegratorState z_info: The initial integrator state.
        :param jax.random.PRNGKey rng_key: Random key to be used as the source of randomness.
        :param float step_size: Initial step size.
        :param inverse_mass_matrix: Inverse of the initial mass matrix. If ``None``,
            inverse of mass matrix will be an identity matrix with size is decided
            by the argument `mass_matrix_size`.
        :param int mass_matrix_size: Size of the mass matrix.
        :return: initial state of the adapt scheme.
        # Reserve a dedicated subkey for the step-size search below.
        rng_key, rng_key_ss = random.split(rng_key)
        if inverse_mass_matrix is None:
            assert mass_matrix_size is not None
            if dense_mass:
                inverse_mass_matrix = jnp.identity(mass_matrix_size)
                # NOTE(review): an `else:` is elided before the next line; the
                # diagonal (non-dense) case uses a vector of ones.
                inverse_mass_matrix = jnp.ones(mass_matrix_size)
            # An identity matrix / vector of ones is its own square root.
            mass_matrix_sqrt = inverse_mass_matrix
            # NOTE(review): an outer `else:` (inverse_mass_matrix supplied by the
            # caller) appears to be elided before the following `if`.
            if dense_mass:
                mass_matrix_sqrt = cholesky_of_inverse(inverse_mass_matrix)
                # NOTE(review): `else:` elided; the diagonal case takes the
                # elementwise sqrt of 1 / inverse_mass_matrix.
                mass_matrix_sqrt = jnp.sqrt(jnp.reciprocal(inverse_mass_matrix))

        if adapt_step_size:
            # Search for an initial step size using the reserved subkey.
            step_size = find_reasonable_step_size(step_size, inverse_mass_matrix, z_info, rng_key_ss)
        # The step-size adaptation state is initialised at log(10 * step_size).
        ss_state = ss_init(jnp.log(10 * step_size))
GitHub: pyro-ppl / numpyro / numpyro / infer (View on GitHub)
# Excerpt of a loop body that proposes initial parameters and reports whether
# the potential energy and its gradient are finite there. The enclosing `def`
# and the `if` that pairs with the `else:` below are not shown.
constrained_values, inv_transforms = {}, {}
            # Collect latent, continuous sample sites and a bijection to each support.
            for k, v in model_trace.items():
                if v['type'] == 'sample' and not v['is_observed'] and not v['fn'].is_discrete:
                    constrained_values[k] = v['value']
                    inv_transforms[k] = biject_to(v['fn'].support)
            # NOTE(review): the closing argument(s) and parenthesis of this call are
            # missing from the excerpt (presumably `invert=True)` — confirm upstream).
            params = transform_fn(inv_transforms,
                                  {k: v for k, v in constrained_values.items()},
        else:  # this branch doesn't require tracing the model
            params = {}
            for k, v in prototype_params.items():
                if k in init_values:
                    params[k] = init_values[k]
                    # NOTE(review): an `else:` is elided here. Sites without an init
                    # value are drawn uniformly from [-radius, radius) with the current
                    # subkey, and the key is then split for the next draw.
                    params[k] = random.uniform(subkey, jnp.shape(v), minval=-radius, maxval=radius)
                    key, subkey = random.split(key)

        potential_fn = partial(potential_energy, model, model_args, model_kwargs, enum=enum)
        pe, z_grad = value_and_grad(potential_fn)(params)
        z_grad_flat = ravel_pytree(z_grad)[0]
        # Valid only if the potential energy and every gradient entry are finite.
        is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad_flat))
        return i + 1, key, (params, pe, z_grad), is_valid
GitHub: pyro-ppl / numpyro / numpyro / infer (View on GitHub)
# Excerpt in two parts:
# - The inner per-sample function replays the model with one set of posterior
#   values substituted and its own PRNG key, then returns the requested sites.
# - The outer code splits one key per sample and maps the function over them.
# The enclosing `def` lines are not shown.
rng_key, samples = val
        model_trace = trace(seed(substitute(model, samples), rng_key)).get_trace(
            *model_args, **model_kwargs)
        if return_sites is not None:
            if return_sites == '':
                # Empty string selects every site that is not a plate.
                sites = {k for k, site in model_trace.items() if site['type'] != 'plate'}
                # NOTE(review): an `else:` is elided before the next line; an
                # explicit collection of names is used as given.
                sites = return_sites
            # NOTE(review): an outer `else:` (return_sites is None) appears elided.
            # Default: sample sites not already substituted, plus deterministic sites.
            sites = {k for k, site in model_trace.items()
                     if (site['type'] == 'sample' and k not in samples) or (site['type'] == 'deterministic')}
        return {name: site['value'] for name, site in model_trace.items() if name in sites}

    # NOTE(review): the argument of `int(` is cut off in this excerpt. Given the
    # reshape below, it presumably equals the number of elements in
    # `batch_shape` — confirm upstream.
    num_samples = int(
    if num_samples > 1:
        # One independent key per posterior sample.
        rng_key = random.split(rng_key, num_samples)
    # Arrange keys to match batch_shape; assumes raw uint32 key pairs (last dim 2).
    rng_key = rng_key.reshape(batch_shape + (2,))
    # All samples in one chunk when parallel, otherwise one at a time.
    chunk_size = num_samples if parallel else 1
    return soft_vmap(single_prediction, (rng_key, posterior_samples), len(batch_shape), chunk_size)
GitHub: google / jax / examples (View on GitHub)
def body_fun(i, opt_state):
    """Take one optimizer step on minibatch ``i``.

    The outer key ``rng`` is folded with the batch index, so every batch gets
    its own deterministic key. That key is split in two: one half drives the
    ELBO estimate and the other is handed to ``binarize_batch``.
    """
    batch_key = random.fold_in(rng, i)
    elbo_rng, data_rng = random.split(batch_key)
    batch = binarize_batch(data_rng, i, train_images)

    def negative_scaled_elbo(params):
        # Divided by batch_size (presumably elbo sums over the batch).
        return -elbo(elbo_rng, params, batch) / batch_size

    current_params = optimizers.get_params(opt_state)
    gradient = grad(negative_scaled_elbo)(current_params)
    return opt_update(i, gradient, opt_state)
    return lax.fori_loop(0, num_batches, body_fun, opt_state)
GitHub: pyro-ppl / numpyro / numpyro / infer (View on GitHub)
def update(self, svi_state, *args, **kwargs):
    """
    Take a single step of SVI (possibly on a batch / minibatch of data),
    using the optimizer.

    :param svi_state: current state of SVI.
    :param args: arguments to the model / guide (these can possibly vary during
        the course of fitting).
    :param kwargs: keyword arguments to the model / guide (these can possibly vary
        during the course of fitting).
    :return: tuple of `(svi_state, loss)`.
    """
    # Spend one half of the key on this step's loss estimate and carry the other
    # half forward in the returned state, so no key is consumed twice.
    rng_key, rng_key_step = random.split(svi_state.rng_key)
    # The optimizer holds unconstrained values; constrain_fn maps them back
    # before the loss sees them.
    params = self.optim.get_params(svi_state.optim_state)
    loss_val, grads = value_and_grad(
        lambda x: self.loss.loss(rng_key_step, self.constrain_fn(x), self.model, self.guide,
                                 *args, **kwargs, **self.static_kwargs))(params)
    optim_state = self.optim.update(grads, svi_state.optim_state)
    return SVIState(optim_state, rng_key), loss_val
GitHub: pyro-ppl / numpyro / numpyro / infer (View on GitHub)
def init(self, rng_key, *args, **kwargs):
    """
    Initialize the SVI state.

    As a side effect, sets ``self.constrain_fn``, which maps unconstrained
    parameter values back to their constrained domains.

    :param jax.random.PRNGKey rng_key: random number generator seed.
    :param args: arguments to the model / guide (these can possibly vary during
        the course of fitting).
    :param kwargs: keyword arguments to the model / guide (these can possibly vary
        during the course of fitting).
    :return: initial :data:`SVIState`, holding the optimizer state initialised
        from the unconstrained parameter values, together with the remaining
        PRNG key.
    """
    # Three independent keys: one to carry forward, one each to seed model and guide.
    rng_key, model_seed, guide_seed = random.split(rng_key, 3)
    model_init = seed(self.model, model_seed)
    guide_init = seed(self.guide, guide_seed)
    guide_trace = trace(guide_init).get_trace(*args, **kwargs, **self.static_kwargs)
    model_trace = trace(model_init).get_trace(*args, **kwargs, **self.static_kwargs)
    params = {}
    inv_transforms = {}
    # NB: params in model_trace will be overwritten by params in guide_trace
    for site in list(model_trace.values()) + list(guide_trace.values()):
        if site['type'] == 'param':
            constraint = site['kwargs'].pop('constraint', constraints.real)
            transform = biject_to(constraint)
            inv_transforms[site['name']] = transform
            # The optimizer works on unconstrained values.
            params[site['name']] = transform.inv(site['value'])

    self.constrain_fn = partial(transform_fn, inv_transforms)
    return SVIState(self.optim.init(params), rng_key)
GitHub: pyro-ppl / funsor / funsor (View on GitHub)
# Excerpt of funsor's Contraction sampling.
# - logaddexp reduction: the incoming key is split into one subkey per term.
# - bin_op is ops.add: the key is split into two subkeys.
# Without a key, lists of None placeholders are used instead.
# NOTE(review): the surrounding `if rng_key is not None:` / `else:` lines are
# elided in the first part, and the excerpt is truncated at the end.
import jax
                    rng_keys = jax.random.split(rng_key, len(self.terms))
                    rng_keys = [None] * len(self.terms)

                # Design choice: we sample over logaddexp reductions, but leave logaddexp
                # binary choices symbolic.
                # NOTE(review): as shown, the per-term `rng_key` from the zip is never
                # passed to `unscaled_sample`, so the split subkeys go unused here
                # (compare the `rng_keys[0]` argument further down) — confirm upstream.
                terms = [
                    term.unscaled_sample(sampled_vars.intersection(term.inputs), sample_inputs)
                    for term, rng_key in zip(self.terms, rng_keys)]
                return Contraction(self.red_op, self.bin_op, self.reduced_vars, *terms)

            if self.bin_op is ops.add:
                if rng_key is not None:
                    import jax
                    # split() with no count yields two subkeys, matching the
                    # two-element None fallback.
                    rng_keys = jax.random.split(rng_key)
                    # NOTE(review): an `else:` is elided before this line.
                    rng_keys = [None] * 2

                # Sample variables greedily in order of the terms in which they appear.
                for term in self.terms:
                    greedy_vars = sampled_vars.intersection(term.inputs)
                    if greedy_vars:
                        # NOTE(review): the body of this `if` (presumably `break`) is elided.
                # Partition terms by whether they mention any greedily sampled variable.
                greedy_terms, terms = [], []
                for term in self.terms:
                    (terms if greedy_vars.isdisjoint(term.inputs) else greedy_terms).append(term)
                if len(greedy_terms) == 1:
                    term = greedy_terms[0]
                    # Only the first subkey is consumed in this branch (as far as shown).
                    terms.append(term.unscaled_sample(greedy_vars, sample_inputs, rng_keys[0]))
                    result = Contraction(self.red_op, self.bin_op, self.reduced_vars, *terms)
                elif (len(greedy_terms) == 2 and
GitHub: tensorflow / probability / tensorflow_probability / python / internal / backend / numpy (View on GitHub)
def body_fn(carry):
    """One pass of the inner loop of Knuth's multiplication method.

    Presumably this is Poisson sampling with rate ``lam``.

    ``carry`` is ``(i, k, rng, log_prod)``: an iteration counter, the running
    count, the PRNG key and the running log-product of uniforms. ``k`` is
    incremented wherever ``log_prod`` is still above ``-lam``. Then the log of
    a fresh uniform of shape ``shape`` is added to ``log_prod``.
    """
    i, k, rng, log_prod = carry
    rng, draw_key = random.split(rng)
    still_above = log_prod > -lam
    k = np.where(still_above, k + 1, k)
    log_prod = log_prod + np.log(random.uniform(draw_key, shape))
    return i + 1, k, rng, log_prod
GitHub: pyro-ppl / numpyro / numpyro / distributions (View on GitHub)
def _btrs_body_fn(val):
    """Draw one candidate for the rejection loop.

    Judging by the name, this is presumably the BTRS binomial sampler.

    Only the key (second slot of ``val``) is read. It is split three ways,
    two uniforms are drawn, and a candidate integer ``k`` is formed from the
    shifted uniform ``u`` and the precomputed ``tr_params`` coefficients.
    Returns ``(k, key, u, v)``, the same 4-tuple layout as ``val``, with ``k``
    cast to ``n.dtype``.
    """
    _prev_k, key, _prev_u, _prev_v = val
    key, key_u, key_v = random.split(key, 3)
    u = random.uniform(key_u)
    v = random.uniform(key_v)
    u = u - 0.5  # shift to [-0.5, 0.5)
    denom = 0.5 - jnp.abs(u)
    k_float = jnp.floor((2 * tr_params.a / denom + tr_params.b) * u + tr_params.c)
    return k_float.astype(n.dtype), key, u, v