How to use the jax.random.normal function in jax

To help you get started, we’ve selected a few examples of jax.random.normal, based on popular ways the function is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github tensorflow / tensor2tensor / tensor2tensor / trax / layers / attention.py View on Github external
self.n_buckets % factor == 0 and
            factor % 2 == 0 and
            (self.n_buckets // factor) % 2 == 0):
          factor -= 1
        if factor > 2:  # Factor of 2 does not warrant the effort.
          rot_size = factor + (self.n_buckets // factor)
          factor_list = [factor, self.n_buckets // factor]

    random_rotations_shape = (
        vecs.shape[-1],
        self.n_hashes if self._rehash_each_round else 1,
        rot_size // 2)

    rng = jax.lax.tie_in(vecs, rng)
    rng, subrng = backend.random.split(rng)
    random_rotations = jax.random.normal(
        rng, random_rotations_shape).astype('float32')
    # TODO(lukaszkaiser): the dropout mask will be used for all rounds of
    # hashing, so it's shared between them. Check if that's what we want.
    dropped_vecs = self.drop_for_hash(vecs, subrng)
    rotated_vecs = np.einsum('tf,fhb->htb', dropped_vecs, random_rotations)

    if self._rehash_each_round:
      if self._factorize_hash and len(factor_list) > 1:
        # We factorized self.n_buckets as the product of factor_list.
        # Get the buckets for them and combine.
        buckets, cur_sum, cur_product = None, 0, 1
        for factor in factor_list:
          rv = rotated_vecs[..., cur_sum:cur_sum + (factor // 2)]
          cur_sum += factor // 2
          rv = np.concatenate([rv, -rv], axis=-1)
          if buckets is None:
github pyro-ppl / numpyro / test / test_mcmc_interface.py View on Github external
def test_improper_normal():
    """Check that the mean of the sampled ``loc`` lands near the true value.

    Observations are ``true_coef`` plus 1000 draws from
    ``random.normal``.  The model declares ``loc`` through
    ``numpyro.param`` with an interval constraint whose upper bound is the
    sampled ``alpha``.  NUTS is run for 1000 warmup and 1000 sampling
    steps, and the mean of the collected ``loc`` values must be within
    0.05 of ``true_coef``.
    """
    true_coef = 0.9

    def model(data):
        upper = numpyro.sample('alpha', dist.Uniform(0, 1))
        center = numpyro.param(
            'loc', 0., constraint=constraints.interval(0., upper))
        numpyro.sample('obs', dist.Normal(center, 0.1), obs=data)

    noise = random.normal(random.PRNGKey(0), (1000,))
    data = true_coef + noise

    mcmc = MCMC(NUTS(model=model), num_warmup=1000, num_samples=1000)
    mcmc.run(random.PRNGKey(0), data)

    loc_mean = np.mean(mcmc.get_samples()['loc'], 0)
    assert_allclose(loc_mean, true_coef, atol=0.05)
github google / jax / examples / spmd_spatially_sharded_conv_net.py View on Github external
def init(rng, shape):
    """Return float32 normal samples of ``shape`` scaled by a fan-based std.

    ``in_axis``, ``out_axis``, ``scale`` and ``axis_name`` are read from
    the enclosing scope.  Both fan sizes are multiplied by
    ``lax.psum(1, axis_name)`` before the standard deviation is computed.

    Args:
      rng: PRNG key passed to ``random.normal``.
      shape: Shape of the array to create.

    Returns:
      ``std * random.normal(rng, shape, dtype=np.float32)``.
    """
    axis_size = lax.psum(1, axis_name)
    fan_in = shape[in_axis] * axis_size
    fan_out = shape[out_axis] * axis_size
    # Product of every dimension other than the in/out axes.
    other_dims = onp.delete(shape, [in_axis, out_axis])
    size = onp.prod(other_dims)
    std = scale / np.sqrt((fan_in + fan_out) / 2. * size)
    samples = random.normal(rng, shape, dtype=np.float32)
    return std * samples
  return init
github JuliusKunze / jaxnet / examples / wavenet.py View on Github external
def get_batches(batches=100, sequence_length=1000, key=PRNGKey(0)):
        """Yield ``batches`` arrays drawn from ``random.normal``.

        Each array has shape ``(1, receptive_field + sequence_length, 1)``,
        where ``receptive_field`` comes from the enclosing scope.  The key
        is split once per batch: one half is carried to the next
        iteration, the other is used to draw that batch.
        """
        for _ in range(batches):
            next_key, sample_key = random.split(key)
            key = next_key
            sample_shape = (1, receptive_field + sequence_length, 1)
            yield random.normal(sample_key, sample_shape)
github pyro-ppl / numpyro / numpyro / distributions / lognorm.py View on Github external
def _rvs(self, s):
        """Return ``exp(s * z)`` with ``z`` drawn by ``random.normal``.

        The draw uses ``self._random_state`` as the key and ``self._size``
        as the shape.
        """
        normal_draw = random.normal(self._random_state, self._size)
        return np.exp(s * normal_draw)
github google / jax / examples / mnist_vae.py View on Github external
def gaussian_sample(rng, mu, sigmasq):
  """Draw one sample from a Gaussian with diagonal covariance.

  Args:
    rng: PRNG key passed to `random.normal`.
    mu: Mean; its shape determines the shape of the noise.
    sigmasq: Per-element variance.

  Returns:
    `mu + sqrt(sigmasq) * eps`, where `eps` comes from `random.normal`.
  """
  sigma = np.sqrt(sigmasq)
  eps = random.normal(rng, mu.shape)
  return mu + sigma * eps
github google / jax / examples / advi.py View on Github external
def diag_gaussian_sample(rng, mean, log_std):
    """Take a single sample from a diagonal multivariate Gaussian.

    Args:
        rng: PRNG key passed to ``random.normal``.
        mean: Mean; its shape determines the shape of the noise.
        log_std: Element-wise log of the standard deviation.

    Returns:
        ``mean + exp(log_std) * eps`` with ``eps`` from ``random.normal``.
    """
    std = np.exp(log_std)
    eps = random.normal(rng, mean.shape)
    return mean + std * eps
github google / jax / examples / differentially_private_sgd.py View on Github external
  # Add noise drawn by random.normal, scaled by std_dev and shaped like n.
  # NOTE(review): rng is captured from the enclosing scope, so every call
  # uses the same key; confirm that identical noise for same-shaped inputs
  # is intended.
  noise_ = lambda n: n + std_dev * random.normal(rng, n.shape)
  # Divide n by batch_size, converted to float first.
  normalize_ = lambda n: n / float(batch_size)
github probml / pyprobml / scripts / mnist_vae_jax.py View on Github external
def gaussian_sample(rng, mu, sigmasq):
  """Sample once from a Gaussian whose covariance is diagonal.

  The noise is drawn with `random.normal` using key `rng` and the shape
  of `mu`, scaled by the square root of `sigmasq`, then shifted by `mu`.
  """
  std = np.sqrt(sigmasq)
  noise = random.normal(rng, mu.shape)
  scaled_noise = std * noise
  return mu + scaled_noise