How to use the jax.numpy.sqrt function in jax

To help you get started, we’ve selected a few examples of how jax.numpy.sqrt is used in popular public projects.
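
Before the project snippets, here is a minimal, self-contained sketch of jnp.sqrt itself: it takes the elementwise square root, just like numpy.sqrt, and composes with jax transformations such as jax.grad and jax.jit.

import jax
import jax.numpy as jnp

x = jnp.array([1.0, 4.0, 9.0])
print(jnp.sqrt(x))  # [1. 2. 3.]

# sqrt is differentiable: d/dv sqrt(v) = 1 / (2 * sqrt(v))
print(jax.grad(jnp.sqrt)(4.0))  # 0.25

# and it compiles under jit like any other jax.numpy op
print(jax.jit(jnp.sqrt)(2.0))  # 1.4142135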

github pyro-ppl / numpyro / numpyro / distributions / util.py
import jax.numpy as jnp
from jax import random


def _von_mises_centered(key, concentration, shape, dtype):
    # Cutoff from TensorFlow probability
    # (https://github.com/tensorflow/probability/blob/f051e03dd3cc847d31061803c2b31c564562a993/tensorflow_probability/python/distributions/von_mises.py#L567-L570)
    s_cutoff_map = {jnp.dtype(jnp.float16): 1.8e-1,
                    jnp.dtype(jnp.float32): 2e-2,
                    jnp.dtype(jnp.float64): 1.2e-4}
    s_cutoff = s_cutoff_map.get(dtype)

    r = 1. + jnp.sqrt(1. + 4. * concentration ** 2)
    rho = (r - jnp.sqrt(2. * r)) / (2. * concentration)
    s_exact = (1. + rho ** 2) / (2. * rho)

    s_approximate = 1. / concentration

    s = jnp.where(concentration > s_cutoff, s_exact, s_approximate)

    def cond_fn(*args):
        """ check if all are done or reached max number of iterations """
        i, _, done, _, _ = args[0]
        return jnp.bitwise_and(i < 100, jnp.logical_not(jnp.all(done)))

    def body_fn(*args):
        i, key, done, _, w = args[0]
        uni_ukey, uni_vkey, key = random.split(key, 3)

        u = random.uniform(key=uni_ukey, shape=shape, dtype=concentration.dtype, minval=-1., maxval=1.)
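
Even in this truncated excerpt the role of jnp.sqrt is clear: the constants appear to come from the Best-Fisher (1979) rejection sampler for the von Mises distribution, which for concentration kappa sets r = 1 + sqrt(1 + 4 * kappa^2) and rho = (r - sqrt(2 * r)) / (2 * kappa). The exact envelope parameter s_exact is numerically unstable for small kappa, so below a dtype-dependent cutoff the code switches to the approximation s ≈ 1 / kappa.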

github tensorflow / cleverhans / cleverhans / future / jax / utils.py
import jax.numpy as np  # np is jax.numpy throughout this module


# (function header reconstructed from the docstring below; cleverhans defines
#  clip_eta(eta, norm, eps) in this module)
def clip_eta(eta, norm, eps):
  """
  Helper function to clip the perturbation eta into the eps norm ball.
  :param norm: Order of the norm (mimics Numpy).
               Possible values: np.inf or 2.
  :param eps: Epsilon, bound of the perturbation.
  """

  # Clipping perturbation eta to the norm ball
  if norm not in [np.inf, 2]:
    raise ValueError('norm must be np.inf or 2.')

  axis = list(range(1, len(eta.shape)))
  avoid_zero_div = 1e-12
  if norm == np.inf:
    eta = np.clip(eta, a_min=-eps, a_max=eps)
  elif norm == 2:
    # avoid_zero_div must go inside sqrt to avoid a divide by zero in the gradient through this operation
    norm = np.sqrt(np.maximum(avoid_zero_div, np.sum(np.square(eta), axis=axis, keepdims=True)))
    # We must *clip* to within the norm ball, not *normalize* onto the surface of the ball
    factor = np.minimum(1., np.divide(eps, norm))
    eta = eta * factor
  return eta
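
A quick sanity check of both branches (a sketch, assuming the clip_eta signature reconstructed above):

eta = np.array([[3.0, 4.0]])  # L2 norm is 5
print(np.linalg.norm(clip_eta(eta, 2, eps=1.0)))  # ~1.0, scaled into the ball
print(clip_eta(eta, np.inf, eps=1.0))             # [[1. 1.]], clipped per coordinate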

github pyro-ppl / numpyro / examples / gp.py
import jax
import jax.numpy as np  # np is jax.numpy in this example


def predict(rng_key, X, Y, X_test, var, length, noise):
    # compute kernels between train and test data, etc.
    k_pp = kernel(X_test, X_test, var, length, noise, include_noise=True)
    k_pX = kernel(X_test, X, var, length, noise, include_noise=False)
    k_XX = kernel(X, X, var, length, noise, include_noise=True)
    K_xx_inv = np.linalg.inv(k_XX)
    K = k_pp - np.matmul(k_pX, np.matmul(K_xx_inv, np.transpose(k_pX)))
    sigma_noise = np.sqrt(np.clip(np.diag(K), a_min=0.)) * jax.random.normal(rng_key, X_test.shape[:1])
    mean = np.matmul(k_pX, np.matmul(K_xx_inv, Y))
    # we return both the mean function and a sample from the posterior predictive for the
    # given set of hyperparameters
    return mean, mean + sigma_noise
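
predict assumes a kernel helper with the signature kernel(X, Z, var, length, noise, include_noise). A minimal RBF kernel in that shape might look like the sketch below; the actual kernel in numpyro's examples/gp.py may differ in detail:

def kernel(X, Z, var, length, noise, include_noise=True, jitter=1e-6):
    # squared-exponential kernel on 1-D inputs, with optional observation noise
    deltaXsq = np.power((X[:, None] - Z) / length, 2.0)
    k = var * np.exp(-0.5 * deltaXsq)
    if include_noise:
        k += (noise + jitter) * np.eye(X.shape[0])
    return k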

github google / jax / jax / experimental / odeint.py
def f(y, t, arg1, arg2):
    return -np.sqrt(t) - y + arg1 - np.mean((y + arg2)**2)
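
f has the (y, t, *args) signature that jax's ODE solver expects, so a usage sketch looks like this (in current jax releases the solver is exposed as jax.experimental.ode.odeint):

import jax.numpy as np
from jax.experimental.ode import odeint

y0 = np.array([0.5])
ts = np.linspace(0.0, 1.0, 11)
ys = odeint(f, y0, ts, 1.0, 2.0)  # integrates dy/dt = f(y, t, arg1=1.0, arg2=2.0)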

github pyro-ppl / numpyro / numpyro / distributions / transforms.py
def inv(self, y):
        # inverse stick-breaking
        z1m_cumprod = 1 - jnp.cumsum(y * y, axis=-1)
        pad_width = [(0, 0)] * y.ndim
        pad_width[-1] = (1, 0)
        z1m_cumprod_shifted = jnp.pad(z1m_cumprod[..., :-1], pad_width,
                                      mode="constant", constant_values=1.)
        t = matrix_to_tril_vec(y, diagonal=-1) / jnp.sqrt(
            matrix_to_tril_vec(z1m_cumprod_shifted, diagonal=-1))
        # inverse of tanh
        x = jnp.log((1 + t) / (1 - t)) / 2
        return x
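
The last line works because log((1 + t) / (1 - t)) / 2 is exactly the inverse hyperbolic tangent, so it could equivalently be written as:

x = jnp.arctanh(t)  # identical to jnp.log((1 + t) / (1 - t)) / 2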

github pyro-ppl / funsor / funsor / jax / ops.py
import jax.numpy as np


def _cholesky(x):
    """
    Like :func:`numpy.linalg.cholesky` but uses sqrt for scalar matrices.
    """
    if x.shape[-1] == 1:
        return np.sqrt(x)
    return np.linalg.cholesky(x)
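
The scalar shortcut is sound because the Cholesky factor of a 1x1 matrix [[a]] is simply [[sqrt(a)]]; a quick check (a sketch):

a = np.array([[4.0]])
print(_cholesky(a))           # [[2.]]
print(np.linalg.cholesky(a))  # [[2.]], same result; sqrt just skips the factorization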

github google / jax / jax / experimental / ode.py
import jax.numpy as np


def optimal_step_size(last_step,
                      mean_error_ratio,
                      safety=0.9,
                      ifactor=10.0,
                      dfactor=0.2,
                      order=5.0):
  """Compute optimal Runge-Kutta stepsize."""
  mean_error_ratio = np.max(mean_error_ratio)
  dfactor = np.where(mean_error_ratio < 1,
                     1.0,
                     dfactor)

  err_ratio = np.sqrt(mean_error_ratio)
  factor = np.maximum(1.0 / ifactor,
                      np.minimum(err_ratio**(1.0 / order) / safety,
                                 1.0 / dfactor))
  return np.where(mean_error_ratio == 0,
                  last_step * ifactor,
                  last_step / factor)
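
A usage sketch: an error ratio above 1 shrinks the step, a ratio below 1 lets it grow, and a ratio of exactly 0 grows it by ifactor:

print(optimal_step_size(0.1, np.array(4.0)))   # error too large -> step shrinks
print(optimal_step_size(0.1, np.array(0.25)))  # error small -> step grows slightly
print(optimal_step_size(0.1, np.array(0.0)))   # no error -> 0.1 * ifactor = 1.0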

github pyro-ppl / numpyro / numpyro / contrib / distributions / continuous.py
def _stats(self, a):
        return a, a, 2.0 / jnp.sqrt(a), 6.0 / a
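
These match the mean, variance, skewness (2 / sqrt(a)), and excess kurtosis (6 / a) of a gamma distribution with shape a and unit scale, which is also what scipy reports; a quick cross-check (a sketch):

from scipy import stats

print(stats.gamma.stats(4.0, moments='mvsk'))  # (4.0, 4.0, 1.0, 1.5)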

github scikit-hep / pyhf / src / pyhf / tensor / jax_backend.py
def sqrt(self, tensor_in):
        return np.sqrt(tensor_in)
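
In pyhf this method is normally reached through the tensorlib abstraction rather than called directly; a sketch assuming pyhf's public set_backend/tensorlib API:

import pyhf

pyhf.set_backend("jax")
tensorlib = pyhf.tensorlib
print(tensorlib.sqrt(tensorlib.astensor([1.0, 4.0, 9.0])))  # [1. 2. 3.]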

github google / jax / examples / gaussian_process_regression.py
  # Create a really simple toy 1D function
  y_fun = lambda x: np.sin(x) + 0.01 * random.normal(key, shape=(x.shape[0], 1))
  x = np.linspace(1., 4., numpts)[:, None]
  y = y_fun(x)
  xtest = np.linspace(0, 5., 200)[:, None]
  ytest = y_fun(xtest)

  for i in range(1000):
    params, momentums, scales = train_step(params, momentums, scales, x, y)
    if i % 50 == 0:
      ml = marginal_likelihood(params, x, y)
      print("Step: %d, neg marginal likelihood: %f" % (i, ml))

  print([i.copy() for i in params])
  mu, var = predict(params, x, y, xtest)
  std = np.sqrt(np.diag(var))
  plt.plot(x, y, "k.")
  plt.plot(xtest, mu)
  plt.fill_between(xtest.flatten(),
                   mu.flatten() - std * 2, mu.flatten() + std * 2)
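
Here np.sqrt(np.diag(var)) converts the predictive covariance into pointwise standard deviations, so the shaded region spans mu ± 2 * std, roughly a 95% band under the Gaussian predictive distribution.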