How to use the jax.numpy module in JAX

To help you get started, we’ve selected a few JAX examples based on popular ways jax.numpy is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github sharadmv / deepx / deepx / backend / jax.py View on Github external
def cholesky(self, A, lower=True, warn=False, correct=True):
    """Return a Cholesky factor of the (possibly batched) matrix ``A``.

    Args:
        A: array whose trailing two axes hold the matrix to factor (assumed
            Hermitian positive-definite -- not checked here).
        lower: if True (default), return the lower-triangular factor ``L``
            with ``A = L @ L^H``. If False, return the upper-triangular
            factor ``U = L^H`` with ``A = U^H @ U``.
        warn: currently unused by this implementation.
        correct: currently unused by this implementation.

    Returns:
        Array with the same shape as ``A`` holding the requested factor.
    """
    lower_factor = np.linalg.cholesky(A)
    if lower:
        return lower_factor
    # np.linalg.cholesky only produces the lower factor. Previously ``lower``
    # was ignored, so callers asking for the upper factor silently got L.
    # The upper factor is the conjugate transpose of L over the last two axes.
    # swapaxes is used rather than .T so that batched inputs keep their
    # leading batch axes in place.
    return np.conj(np.swapaxes(lower_factor, -1, -2))
github probml / pyprobml / scripts / einsum_demo.py View on Github external
# NOTE(review): this snippet is truncated. n and c (presumably loop indices
# of enclosing loops) and S, W, V, L, D, K, T are bound outside this view.
# Naive reference: L[n, c] = sum over d, k, t of S[n,t,k] * W[k,d] * V[d,c].
s = 0
        for d in range(D):
            for k in range(K):
                for t in range(T):
                    s += S[n,t,k] * W[k,d] * V[d,c]
        L[n,c] = s
# The explicit-loop result must match a single einsum over the same subscripts.
assert np.allclose(L, np.einsum('ntk,kd,dc->nc', S, W, V))


# Precompute an optimal contraction order. einsum_path returns
# (path, report); [0] keeps just the path, which einsum accepts via optimize=.
path = np.einsum_path('ntk,kd,dc->nc', S, W, V, optimize='optimal')[0]
assert np.allclose(L, np.einsum('ntk,kd,dc->nc', S, W, V, optimize=path))


# Same check with jax.numpy's einsum_path / einsum; the JAX result is
# compared against the NumPy reference L.
import jax.numpy as jnp
path = jnp.einsum_path('ntk,kd,dc->nc', S, W, V, optimize='optimal')[0]
assert np.allclose(L, jnp.einsum('ntk,kd,dc->nc', S, W, V, optimize=path))

# Use the full student network from Koller and Friedman: one table per factor,
# contracted down to a scalar. The tables are random normal draws, not
# normalized CPTs. Only their shapes matter to einsum_path, which chooses the
# contraction order.
# The subscripts are named `subscripts` (not `str`) so the builtin str is not
# shadowed for the rest of the module.
subscripts = 'c,dc,gdi,si,lg,jls,hgj->'
K = 5  # every variable in the network takes K states
cptC = np.random.randn(K)
cptD = np.random.randn(K,K)
cptG = np.random.randn(K,K,K)
cptS = np.random.randn(K,K)
cptL = np.random.randn(K,K)
cptJ = np.random.randn(K,K,K)
cptH = np.random.randn(K,K,K)
cpts = [cptC, cptD, cptG, cptS, cptL, cptJ, cptH]
# einsum_path returns (path, report): the contraction order (a list starting
# with 'einsum_path') and a human-readable cost summary.
path_info = np.einsum_path(subscripts, *cpts, optimize='optimal')
print(path_info[0])  # e.g. ['einsum_path', (0, 1), (0, 5), (0, 4), (0, 3), (0, 2), (0, 1)]
print(path_info[1])
'''
github team-ocean / veros / veros / core / operators.py View on Github external
def compute_primes(last_primes, x):
    """Run one forward-sweep step of the Thomas (tridiagonal) algorithm.

    With the previous modified coefficients ``(c'_{i-1}, d'_{i-1})`` and the
    current row ``(a_i, b_i, c_i, d_i)``, computes::

        c'_i = c_i / (b_i - a_i * c'_{i-1})
        d'_i = (d_i - a_i * d'_{i-1}) / (b_i - a_i * c'_{i-1})

    Conventionally a, b, c and d are the sub-diagonal, diagonal,
    super-diagonal and right-hand side.

    Returns:
        ``(primes, primes)``, the same array twice, where
        ``primes = np.stack((c'_i, d'_i))``. Presumably this is the
        (carry, output) pair of a scan-style loop.
    """
    prev_cp, prev_dp = last_primes
    sub, diag, sup, rhs = x
    # Shared denominator of both recurrences.
    pivot = diag - sub * prev_cp
    primes = np.stack((sup / pivot, (rhs - sub * prev_dp) / pivot))
    return primes, primes
github pyro-ppl / numpyro / examples / cg.py View on Github external
def pcg_body_fun(state, mvm, presolve):
    """Advance a preconditioned conjugate gradient (PCG) solve by one iteration.

    Args:
        state: tuple ``(x, r, p, z, r_dot_z, iteration)``: current iterate,
            residual, search direction, preconditioned residual, ``<r, z>``,
            and iteration counter.
        mvm: callable applied to the search direction (``Ap = mvm(p)``).
            Presumably the matrix-vector product with the system matrix.
        presolve: callable applied to the new residual to obtain ``z``.
            Presumably the preconditioner solve.

    Returns:
        A new ``PCGState`` with all fields advanced and ``iteration + 1``.
    """
    x, r, p, z, r_dot_z, iteration = state
    a_p = mvm(p)
    step = r_dot_z / np.dot(p, a_p)
    x_next = x + step * p
    r_next = r - step * a_p
    z_next = presolve(r_next)
    r_dot_z_next = np.dot(r_next, z_next)
    # Fletcher-Reeves-style coefficient: new <r, z> over the previous one.
    p_next = z_next + (r_dot_z_next / r_dot_z) * p
    return PCGState(x_next, r_next, p_next, z_next, r_dot_z_next, iteration + 1)
github google / jax / jax / random.py View on Github external
# NOTE(review): snippet starts mid-function. The multivariate_normal signature
# and the `if` guarding the first raise (presumably cov.ndim < 2) are not
# visible here.
msg = "multivariate_normal requires cov.ndim >= 2, got cov.ndim == {}"
    raise ValueError(msg.format(onp.ndim(cov)))
  n = mean.shape[-1]
  # cov must end in an (n, n) block matching the last axis of mean.
  if onp.shape(cov)[-2:] != (n, n):
    msg = ("multivariate_normal requires cov.shape == (..., n, n) for n={n}, "
           "but got cov.shape == {shape}.")
    raise ValueError(msg.format(n=n, shape=onp.shape(cov)))

  if shape is None:
    # Default sample shape: broadcast of mean's and cov's batch dimensions.
    shape = lax.broadcast_shapes(mean.shape[:-1], cov.shape[:-2])
  else:
    # NOTE(review): the branch above treats cov.shape[:-2] as cov's batch
    # shape, but this call passes mean.shape[:-2]. It looks like it should be
    # cov.shape[:-2]. The "normal" label (used in _check_shape's messages?)
    # may also be meant to read "multivariate_normal". Confirm.
    _check_shape("normal", shape, mean.shape[:-1], mean.shape[:-2])

  # Draw standard normals with trailing size n and map them through the
  # Cholesky factor of cov, then shift by mean.
  chol_factor = cholesky(cov)
  normal_samples = normal(key, shape + mean.shape[-1:], dtype)
  # tensordot contracts the last axis of normal_samples with axis 1 of
  # chol_factor, giving (chol_factor @ z) for a single 2-D cov.
  # NOTE(review): tensordot does not batch over cov's leading axes, so for
  # cov.ndim > 2 the batch axes end up combined outer-product style rather
  # than matched. Consider einsum('...ij,...j->...i', chol_factor,
  # normal_samples). Confirm.
  return mean + np.tensordot(normal_samples, chol_factor, [-1, 1])
github pyro-ppl / numpyro / numpyro / infer / sa.py View on Github external
def _sample_proposal(inv_mass_matrix_sqrt, rng_key, batch_shape=()):
    eps = random.normal(rng_key, batch_shape + jnp.shape(inv_mass_matrix_sqrt)[:1])
    if inv_mass_matrix_sqrt.ndim == 1:
        r = jnp.multiply(inv_mass_matrix_sqrt, eps)
    elif inv_mass_matrix_sqrt.ndim == 2:
        r = jnp.matmul(inv_mass_matrix_sqrt, eps[..., None])[..., 0]
    else:
        raise ValueError("Mass matrix has incorrect number of dims.")
    return r
github google / jax / jax / experimental / odeint.py View on Github external
def interp_fit_dopri(y0, y1, k, dt):
  """Fit a polynomial to the results of a Runge-Kutta step.

  Args:
    y0: state at the start of the step.
    y1: state at the end of the step.
    k: stage values of the step. All of ``k`` is weighted by ``dps_c_mid``
      to estimate the midpoint, and ``k[0]`` / ``k[-1]`` are passed to the fit.
    dt: step size.

  Returns:
    ``np.array`` of whatever ``fit_4th_order_polynomial`` returns (presumably
    the polynomial coefficients -- that helper is defined elsewhere).
  """
  # Midpoint estimate y0 + dt * (dps_c_mid . k). dps_c_mid is a module-level
  # constant not visible here (presumably the Dormand-Prince midpoint weights).
  midpoint = y0 + dt * np.dot(dps_c_mid, k)
  first_stage, last_stage = k[0], k[-1]
  coeffs = fit_4th_order_polynomial(y0, y1, midpoint, first_stage, last_stage, dt)
  return np.array(coeffs)
github google / jax / jax / experimental / stax.py View on Github external
def apply_fun(params, x, **kwargs):
    """Normalize ``x`` and optionally scale by ``gamma`` / shift by ``beta``.

    ``axis``, ``epsilon``, ``center``, ``scale`` and ``normalize`` come from
    the enclosing scope (not visible here).

    Args:
        params: pair ``(beta, gamma)`` of shift and scale parameters.
        x: input array.
        **kwargs: ignored.

    Returns:
        ``normalize(x, axis, epsilon=epsilon)``, multiplied by ``gamma`` when
        ``scale`` is set and shifted by ``beta`` when ``center`` is set.
    """
    beta, gamma = params
    # TODO(phawkins): np.expand_dims should accept an axis tuple.
    # (https://github.com/numpy/numpy/issues/12290)
    # Index tuple that inserts a singleton axis at every normalized axis, so
    # per-feature parameters broadcast against x.
    ed = tuple(None if i in axis else slice(None) for i in range(np.ndim(x)))
    z = normalize(x, axis, epsilon=epsilon)
    # Index only the parameters that are actually used. The original indexed
    # both unconditionally, which fails when a disabled parameter is a
    # non-array placeholder (presumably an empty tuple when center/scale is
    # off -- verify against init_fun).
    if center and scale:
        return gamma[ed] * z + beta[ed]
    if center:
        return z + beta[ed]
    if scale:
        return gamma[ed] * z
    return z
  return init_fun, apply_fun
github pyro-ppl / numpyro / examples / sparse_regression.py View on Github external
# NOTE(review): snippet starts mid-function. eta1, xisq, msq, lam, X, probe,
# kernel, c, sigma, N, Y, vec, num_coefficients, cho_factor, cho_solve and
# solve_triangular are bound outside this view.
eta2 = jnp.square(eta1) * jnp.sqrt(xisq) / msq
    kappa = jnp.sqrt(msq) * lam / jnp.sqrt(msq + jnp.square(eta1 * lam))

    # Rescale the training inputs and probe points by kappa before
    # evaluating the kernel.
    kX = kappa * X
    kprobe = kappa * probe

    # Kernel matrix on the training inputs with sigma**2 added on the
    # diagonal, and its lower Cholesky factor L.
    k_xx = kernel(kX, kX, eta1, eta2, c) + sigma ** 2 * jnp.eye(N)
    L = cho_factor(k_xx, lower=True)[0]
    k_probeX = kernel(kprobe, kX, eta1, eta2, c)
    k_prbprb = kernel(kprobe, kprobe, eta1, eta2, c)

    # Mean: k(probe, X) @ k_xx^{-1} Y (solved via the Cholesky factor),
    # then combined with vec and summed over the last axis.
    mu = jnp.matmul(k_probeX, cho_solve((L, True), Y))
    mu = jnp.sum(mu * vec, axis=-1)

    # Covariance: k(probe, probe) - V^T V with V = L^{-1} k(X, probe),
    # i.e. k(probe, probe) - k(probe, X) k_xx^{-1} k(X, probe); then
    # projected as vec @ covar @ vec^T.
    Linv_k_probeX = solve_triangular(L, jnp.transpose(k_probeX), lower=True)
    covar = k_prbprb - jnp.matmul(jnp.transpose(Linv_k_probeX), Linv_k_probeX)
    covar = jnp.matmul(vec, jnp.matmul(covar, jnp.transpose(vec)))

    # sample from N(mu, covar)
    # NOTE(review): if covar is only numerically positive semi-definite,
    # jnp.linalg.cholesky can return NaNs; consider adding diagonal jitter.
    L = jnp.linalg.cholesky(covar)
    # NOTE(review): the noise comes from NumPy's global RNG (np.random), not a
    # JAX PRNG key. It is not reproducible from a key, and under jit it would
    # be drawn once at trace time and baked in as a constant. Confirm intended.
    sample = mu + jnp.matmul(L, np.random.randn(num_coefficients))

    return sample
github tensorflow / probability / discussion / fun_mcmc / tf_on_jax.py View on Github external
# Register jax.numpy functions under their TensorFlow-style names.
# Each row is (positional args, keyword args, attribute of np). It expands to
# exactly `_impl_np(*args, **kwargs)(np.<attribute>)`, in the same order and
# with the same call shape as writing each call out by hand. The ['math']
# argument is presumably the TF submodule path for the alias (what _impl_np
# does with it is defined elsewhere). Each row builds its own list object.
for _args, _kwargs, _attr in (
    ((), {}, 'where'),
    ((), {}, 'zeros'),
    ((), {}, 'zeros_like'),
    ((), {}, 'maximum'),
    ((), {}, 'minimum'),
    ((['math'],), {}, 'log'),
    ((['math'],), {}, 'sqrt'),
    ((['math'],), {'name': 'pow'}, 'power'),
    ((['math'],), {'name': 'reduce_prod'}, 'prod'),
    ((['math'],), {'name': 'reduce_variance'}, 'var'),
    ((), {'name': 'abs'}, 'abs'),
    ((), {'name': 'Tensor'}, 'ndarray'),
    ((), {'name': 'concat'}, 'concatenate'),
    ((), {'name': 'constant'}, 'array'),
    ((), {'name': 'expand_dims'}, 'expand_dims'),
    ((), {'name': 'range'}, 'arange'),
    ((), {'name': 'reduce_max'}, 'max'),
    ((), {'name': 'reduce_mean'}, 'mean'),
    ((), {'name': 'reduce_sum'}, 'sum'),
    ((), {'name': 'square'}, 'square'),
):
    # _impl_np(...) is evaluated before the np attribute lookup, matching the
    # evaluation order of the hand-written `_impl_np(...)(np.attr)` calls.
    _impl_np(*_args, **_kwargs)(getattr(np, _attr))
# Keep the module namespace identical to the unrolled version.
del _args, _kwargs, _attr