# How to use the jax.numpy.sqrt function in jax

## To help you get started, we’ve selected a few `jax.numpy.sqrt` examples, based on popular ways the function is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

pyro-ppl / numpyro / numpyro / distributions / util.py View on Github
``````def _von_mises_centered(key, concentration, shape, dtype):
# Cutoff from TensorFlow probability
# (https://github.com/tensorflow/probability/blob/f051e03dd3cc847d31061803c2b31c564562a993/tensorflow_probability/python/distributions/von_mises.py#L567-L570)
s_cutoff_map = {jnp.dtype(jnp.float16): 1.8e-1,
jnp.dtype(jnp.float32): 2e-2,
jnp.dtype(jnp.float64): 1.2e-4}
s_cutoff = s_cutoff_map.get(dtype)

r = 1. + jnp.sqrt(1. + 4. * concentration ** 2)
rho = (r - jnp.sqrt(2. * r)) / (2. * concentration)
s_exact = (1. + rho ** 2) / (2. * rho)

s_approximate = 1. / concentration

s = jnp.where(concentration > s_cutoff, s_exact, s_approximate)

def cond_fn(*args):
""" check if all are done or reached max number of iterations """
i, _, done, _, _ = args[0]
return jnp.bitwise_and(i < 100, jnp.logical_not(jnp.all(done)))

def body_fn(*args):
i, key, done, _, w = args[0]
uni_ukey, uni_vkey, key = random.split(key, 3)

u = random.uniform(key=uni_ukey, shape=shape, dtype=concentration.dtype, minval=-1., maxval=1.)``````
tensorflow / cleverhans / cleverhans / future / jax / utils.py View on Github
``````:param norm: Order of the norm (mimics Numpy).
Possible values: np.inf or 2.
:param eps: Epsilon, bound of the perturbation.
"""

# Clipping perturbation eta to self.norm norm ball
if norm not in [np.inf, 2]:
raise ValueError('norm must be np.inf or 2.')

axis = list(range(1, len(eta.shape)))
avoid_zero_div = 1e-12
if norm == np.inf:
eta = np.clip(eta, a_min=-eps, a_max=eps)
elif norm == 2:
# avoid_zero_div must go inside sqrt to avoid a divide by zero in the gradient through this operation
norm = np.sqrt(np.maximum(avoid_zero_div, np.sum(np.square(eta), axis=axis, keepdims=True)))
# We must *clip* to within the norm ball, not *normalize* onto the surface of the ball
factor = np.minimum(1., np.divide(eps, norm))
eta = eta * factor
return eta
``````
pyro-ppl / numpyro / examples / gp.py View on Github
def predict(rng_key, X, Y, X_test, var, length, noise):
    """Return the predictive mean at ``X_test`` and one noisy draw around it.

    The covariance blocks are built by ``kernel`` (defined elsewhere in this
    module); ``var``, ``length`` and ``noise`` are forwarded to it unchanged.

    :param rng_key: PRNG key handed to ``jax.random.normal``.
    :param X: training inputs.
    :param Y: training targets.
    :param X_test: test inputs; one normal variate is drawn per entry along
        its leading axis.
    :param var: kernel hyperparameter forwarded to ``kernel``.
    :param length: kernel hyperparameter forwarded to ``kernel``.
    :param noise: kernel hyperparameter forwarded to ``kernel``.
    :return: tuple ``(mean, mean + sigma_noise)``.
    """
    # compute kernels between train and test data, etc.
    k_pp = kernel(X_test, X_test, var, length, noise, include_noise=True)
    k_pX = kernel(X_test, X, var, length, noise, include_noise=False)
    k_XX = kernel(X, X, var, length, noise, include_noise=True)
    K_xx_inv = np.linalg.inv(k_XX)
    K = k_pp - np.matmul(k_pX, np.matmul(K_xx_inv, np.transpose(k_pX)))
    # Floor the diagonal at zero before the square root so a slightly negative
    # entry cannot produce NaN. ``np.maximum`` is the one-sided clip and does
    # not depend on the ``a_min`` keyword of ``clip``.
    sigma_noise = np.sqrt(np.maximum(np.diag(K), 0.)) * jax.random.normal(
        rng_key, X_test.shape[:1])
    mean = np.matmul(k_pX, np.matmul(K_xx_inv, Y))
    # we return both the mean function and a sample from the posterior predictive for the
    # given set of hyperparameters
    return mean, mean + sigma_noise
google / jax / jax / experimental / odeint.py View on Github
def f(y, t, arg1, arg2):
    """Evaluate ``-sqrt(t) - y + arg1 - mean((y + arg2)**2)``.

    NOTE(review): the ``(y, t, *args)`` signature suggests this is a derivative
    function for an ODE integrator -- confirm against the caller.

    :param y: current state (scalar or array).
    :param t: time value; must be non-negative for a real-valued square root.
    :param arg1: additive offset.
    :param arg2: shift applied to ``y`` before squaring and averaging.
    :return: the expression above, broadcast against ``y``.
    """
    return -np.sqrt(t) - y + arg1 - np.mean((y + arg2)**2)
pyro-ppl / numpyro / numpyro / distributions / transforms.py View on Github
def inv(self, y):
    """Map ``y`` back to the unconstrained values it was built from.

    Undoes the stick-breaking step on the strictly lower-triangular entries of
    ``y`` and then applies the inverse of ``tanh``.

    NOTE(review): the ``pad_width[-1]`` assignment and the ``jnp.pad`` call
    were restored to match upstream NumPyro -- verify against the pinned
    version.

    :param y: array whose last two axes form a matrix; only entries below the
        diagonal are read by ``matrix_to_tril_vec(..., diagonal=-1)``.
    :return: vector of unconstrained values, one per strictly-lower entry.
    """
    # inverse stick-breaking
    z1m_cumprod = 1 - jnp.cumsum(y * y, axis=-1)
    pad_width = [(0, 0)] * y.ndim
    # Shift the running remainder one slot to the right along the last axis;
    # the vacated first slot is filled with 1.
    pad_width[-1] = (1, 0)
    z1m_cumprod_shifted = jnp.pad(z1m_cumprod[..., :-1], pad_width,
                                  mode="constant", constant_values=1.)
    t = matrix_to_tril_vec(y, diagonal=-1) / jnp.sqrt(
        matrix_to_tril_vec(z1m_cumprod_shifted, diagonal=-1))
    # inverse of tanh
    x = jnp.log((1 + t) / (1 - t)) / 2
    return x
pyro-ppl / funsor / funsor / jax / ops.py View on Github
def _cholesky(x):
    """
    Like :func:`numpy.linalg.cholesky` but uses sqrt for scalar matrices.

    :param x: array whose trailing axes hold the matrix (or matrices) to factor.
    :return: ``np.sqrt(x)`` when the last axis has length 1, otherwise
        ``np.linalg.cholesky(x)``.
    """
    # The Cholesky factor of a 1x1 matrix [[a]] is [[sqrt(a)]], so the general
    # routine is skipped for that case.
    if x.shape[-1] == 1:
        return np.sqrt(x)
    return np.linalg.cholesky(x)
google / jax / jax / experimental / ode.py View on Github
def optimal_step_size(last_step,
                      mean_error_ratio,
                      safety=0.9,
                      ifactor=10.0,
                      dfactor=0.2,
                      order=5.0):
    """Compute optimal Runge-Kutta stepsize.

    :param last_step: size of the step that was just attempted.
    :param mean_error_ratio: error ratio(s) of that step; only the maximum
        is used.
    :param safety: divisor applied to the raw scaling factor.
    :param ifactor: largest allowed growth; the new step is at most
        ``last_step * ifactor``.
    :param dfactor: largest allowed shrink; the new step is at least
        ``last_step * dfactor``.
    :param order: the scaling factor is ``sqrt(ratio) ** (1 / order)``.
    :return: the proposed next step size.
    """
    mean_error_ratio = np.max(mean_error_ratio)
    # When the worst error ratio is below 1, the lower clamp becomes 1.0, so
    # the proposed step is never smaller than ``last_step``.
    dfactor = np.where(mean_error_ratio < 1,
                       1.0,
                       dfactor)

    err_ratio = np.sqrt(mean_error_ratio)
    # ``factor`` is clamped to [1 / ifactor, 1 / dfactor] before dividing.
    factor = np.maximum(1.0 / ifactor,
                        np.minimum(err_ratio**(1.0 / order) / safety,
                                   1.0 / dfactor))
    # A zero error ratio takes the maximum growth directly.
    return np.where(mean_error_ratio == 0,
                    last_step * ifactor,
                    last_step / factor)
pyro-ppl / numpyro / numpyro / contrib / distributions / continuous.py View on Github
def _stats(self, a):
    """Return the 4-tuple ``(a, a, 2 / sqrt(a), 6 / a)``.

    NOTE(review): presumably mean, variance, skewness and excess kurtosis of
    a distribution with shape parameter ``a`` -- confirm against the
    enclosing class.

    :param a: positive parameter (scalar or array).
    :return: the four values above.
    """
    return a, a, 2.0 / jnp.sqrt(a), 6.0 / a
scikit-hep / pyhf / src / pyhf / tensor / jax_backend.py View on Github
def sqrt(self, tensor_in):
    """Return the element-wise square root of ``tensor_in``.

    :param tensor_in: array-like input passed straight to ``np.sqrt``.
    :return: the result of ``np.sqrt(tensor_in)``.
    """
    return np.sqrt(tensor_in)
google / jax / examples / gaussian_process_regression.py View on Github
``````# Create a really simple toy 1D function
y_fun = lambda x: np.sin(x) + 0.01 * random.normal(key, shape=(x.shape[0], 1))
x = np.linspace(1., 4., numpts)[:, None]
y = y_fun(x)
xtest = np.linspace(0, 5., 200)[:, None]
ytest = y_fun(xtest)

for i in range(1000):
params, momentums, scales = train_step(params, momentums, scales, x, y)
if i % 50 == 0:
ml = marginal_likelihood(params, x, y)
print("Step: %d, neg marginal likelihood: %f" % (i, ml))

print([i.copy() for i in params])
mu, var = predict(params, x, y, xtest)
std = np.sqrt(np.diag(var))
plt.plot(x, y, "k.")
plt.plot(xtest, mu)
plt.fill_between(xtest.flatten(),
mu.flatten() - std * 2, mu.flatten() + std * 2)``````

## jax

Differentiate, compile, and transform NumPy code.

Apache-2.0