How to use the `jax.numpy.array` function in JAX

To help you get started, we’ve selected a few JAX examples that reflect popular ways this function is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github pyro-ppl / numpyro / test / test_optimizers.py View on Github external
def test_optim_multi_params(optim_class, args):
    """Check that ``optim_class(*args)`` drives a dict of parameters to zero.

    Two 3-element parameter arrays start away from the origin. The optimizer is
    advanced 2000 times with the module-level ``step`` helper. Afterwards every
    array returned by ``opt.get_params`` must be close to ``np.zeros(3)``.
    """
    params = {'x': np.array([1., 1., 1.]), 'y': np.array([-1., -1., -1.])}
    opt = optim_class(*args)
    opt_state = opt.init(params)
    for _ in range(2000):
        opt_state = step(opt_state, opt)
    # Only the parameter values are checked; the dict keys are irrelevant here.
    for param in opt.get_params(opt_state).values():
        assert np.allclose(param, np.zeros(3))
github pyro-ppl / numpyro / test / test_mcmc_interface.py View on Github external
def test_dirichlet_categorical(kernel_cls, dense_mass):
    """Check that MCMC recovers the category probabilities of a Dirichlet-Categorical model.

    The model places a uniform Dirichlet prior (concentration all ones) on a
    3-category probability vector ``p_latent`` and observes categorical data.
    2000 observations are drawn from ``true_probs``. ``kernel_cls`` is then run
    with ``warmup_steps`` / ``num_samples`` (passed positionally to ``MCMC``).
    The posterior mean of ``p_latent`` must match ``true_probs`` within 0.02.

    Args:
        kernel_cls: MCMC kernel class; it must accept ``trajectory_length`` and
            ``dense_mass`` keyword arguments.
        dense_mass: forwarded unchanged to ``kernel_cls``.
    """
    warmup_steps, num_samples = 100, 20000

    def model(data):
        concentration = np.array([1.0, 1.0, 1.0])
        p_latent = numpyro.sample('p_latent', dist.Dirichlet(concentration))
        numpyro.sample('obs', dist.Categorical(p_latent), obs=data)
        return p_latent

    true_probs = np.array([0.1, 0.6, 0.3])
    data = dist.Categorical(true_probs).sample(random.PRNGKey(1), (2000,))
    kernel = kernel_cls(model, trajectory_length=1., dense_mass=dense_mass)
    mcmc = MCMC(kernel, warmup_steps, num_samples, progress_bar=False)
    mcmc.run(random.PRNGKey(2), data)
    samples = mcmc.get_samples()
    assert_allclose(np.mean(samples['p_latent'], 0), true_probs, atol=0.02)

    # JAX's 64-bit switch is the upper-case ``JAX_ENABLE_X64`` variable; a
    # mixed-case spelling never matches, which would make this check dead code.
    if 'JAX_ENABLE_X64' in os.environ:
        assert samples['p_latent'].dtype == np.float64
github pyro-ppl / numpyro / examples / vjp.py View on Github external
return (quad_form, np.mean(res_norm), np.mean(iters)), (tangent_out, 0.0, 0.0)


if __name__ == '__main__':
    # Small random test problem for the quadratic-form / log-det routines used
    # below. NOTE(review): values come from NumPy's global RNG and no seed is
    # set in this block, so results vary between runs; the order of the
    # onp.random calls also determines which random values each array gets.
    N = 5  # number of rows of X (and length of b and diag)
    P = 4  # number of columns of X (and length of kappa)
    b = np.array(onp.random.randn(N))
    X = np.array(onp.random.randn(N * P).reshape((N, P)))
    # dkX is not used in the visible part of this block -- presumably a tangent
    # direction for X used further down; confirm.
    dkX = np.array(onp.random.randn(N * P).reshape((N, P)))
    #Ainv_b_probes = np.array(onp.random.randn(N * 2).reshape((2, N)))
    #probes = np.array(onp.random.randn(N * 1).reshape((1, N)))
    # Per-column scales, uniform in [0.3, 2.3).
    kappa = 0.3 + 2.0 * np.array(onp.random.rand(P))
    eta1 = 0.8
    eta2 = 0.5
    diag = np.array(onp.random.rand(N))
    c = 1.0
    #num_probes = 1
    #probes = np.array(onp.random.randn(N * num_probes).reshape((num_probes, N)))
    # Deterministic probes: row i is sqrt(N) times the i-th standard basis
    # vector (used in place of the random probes commented out above).
    probes = math.sqrt(N) * np.eye(N)

    def direct(_kappa, _b, _eta1, _eta2, _diag, include_log_det):
        """Dense reference computation that builds the kernel matrix explicitly.

        Scales the columns of X by ``_kappa``, forms ``k = kernel(kX, kX, ...)``
        and passes ``k + diag(_diag)``, ``k @ _b`` and ``_b * _diag`` to
        ``direct_quad_form_log_det`` (defined outside this block).
        """
        kX = _kappa * X
        k = kernel(kX, kX, _eta1, _eta2, c)
        k_diag = k + np.diag(_diag)
        return direct_quad_form_log_det(k_diag, np.matmul(k, _b), _b * _diag, include_log_det=include_log_det)

    def pcpcg(_kappa, _b, _eta1, _eta2, _diag):
        """Return element [0] of ``pcpcg_quad_form_log_det2`` for these inputs.

        Presumably the iterative (preconditioned CG) counterpart of ``direct``.
        NOTE(review): the positional constants 2, 2, 1.0e-5, 400, 1 are unnamed
        here -- presumably ranks, a CG tolerance, an iteration cap and a final
        option; check the callee's signature.
        """
        return pcpcg_quad_form_log_det2(_kappa, _b, _eta1, _eta2, _diag, c, X,
                                       probes, 2, 2, 1.0e-5, 400, 1)[0]

    # Presumably selects which comparison the code following this excerpt runs.
    which = 3
github pyro-ppl / numpyro / examples / pairwise.py View on Github external
def do_chunk(svi_state):
    """Run ``body_fn`` through ``_fori_loop`` for indices 0 .. ``report_frequency``.

    The loop carry starts as ``(svi_state, np.array(0.0), np.zeros(2))``. The
    value returned by ``_fori_loop`` is passed straight back to the caller.
    NOTE(review): ``_fori_loop``, ``body_fn`` and ``report_frequency`` come from
    the enclosing scope; the meaning of the two accumulator slots is defined by
    ``body_fn`` and is not visible here.
    """
    lower = np.array(0)
    upper = np.array(report_frequency)
    initial_carry = (svi_state, np.array(0.0), np.zeros(2))
    return _fori_loop(lower, upper, body_fn, initial_carry)
github google / jax / jax / experimental / ode.py View on Github external
y0: initial value for the state.
      f0: initial value for the derivative, computed from `func(t0, y0)`.
      t0: initial time.
      dt: time step.
      alpha, beta, c: Butcher tableau describing how to take the Runge-Kutta
        step.

  Returns:
      y1: estimated function at t1 = t0 + dt
      f1: derivative of the state at t1
      y1_error: estimated error at t1
      k: list of Runge-Kutta coefficients `k` used for calculating these terms.
  """
  # Dopri5 Butcher tableaux
  alpha = np.array([1 / 5, 3 / 10, 4 / 5, 8 / 9, 1., 1., 0])
  beta = np.array(
      [[1 / 5, 0, 0, 0, 0, 0, 0],
       [3 / 40, 9 / 40, 0, 0, 0, 0, 0],
       [44 / 45, -56 / 15, 32 / 9, 0, 0, 0, 0],
       [19372 / 6561, -25360 / 2187, 64448 / 6561, -212 / 729, 0, 0, 0],
       [9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656, 0, 0],
       [35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84, 0]])
  c_sol = np.array([35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84,
                    0])
  c_error = np.array([35 / 384 - 1951 / 21600, 0, 500 / 1113 - 22642 / 50085,
                      125 / 192 - 451 / 720, -2187 / 6784 - -12231 / 42400,
                      11 / 84 - 649 / 6300, -1. / 60.])

  def _fori_body_fun(i, val):
    ti = t0 + dt * alpha[i-1]
    yi = y0 + dt * np.dot(beta[i-1, :], val)
    ft = func(yi, ti)
github google / TensorNetwork / examples / simple_mera / simple_mera.py View on Github external
def ham_ising():
  """Dimension 2 "Ising" Hamiltonian.

  This version from Evenbly & White, Phys. Rev. Lett. 116, 140403 (2016).

  Builds the three-site operator  X.Z.X - 0.5 * (X.X.I + I.X.X), where "."
  denotes the Kronecker product and I the 2x2 identity, as an 8x8 matrix.

  Returns:
    The operator reshaped to shape (2, 2, 2, 2, 2, 2). The first three axes
    index the row (output) sites and the last three the column (input) sites.
    The dtype is floating point.
  """
  E = np.array([[1, 0], [0, 1]])
  X = np.array([[0, 1], [1, 0]])
  Z = np.array([[1, 0], [0, -1]])
  # Subtract out of place: the Pauli matrices above are integer arrays, and an
  # in-place ``-=`` of the float 0.5 * (...) term fails under NumPy's
  # same_kind casting rule. jax.numpy rebinds instead, so both backends agree
  # with this form.
  hmat = np.kron(X, np.kron(Z, X)) - 0.5 * (
      np.kron(np.kron(X, X), E) + np.kron(E, np.kron(X, X)))
  return np.reshape(hmat, [2]*6)
github pyro-ppl / numpyro / examples / sparse_regression.py View on Github external
def compute_pairwise_mean_variance(X, Y, dim1, dim2, msq, lam, eta1, xisq, c, sigma):
    """Compute a pairwise-interaction contrast for features ``dim1`` and ``dim2``.

    Four probe points are built that are +/-1 in columns ``dim1`` and ``dim2``
    (sign patterns (+,+), (+,-), (-,+), (-,-)) and zero elsewhere. A GP-style
    predictive mean ``k(probe, X) @ inv(k(X, X) + sigma**2 I) @ Y`` is evaluated
    at those probes. The four means are then combined with the contrast vector
    [0.25, -0.25, -0.25, 0.25].

    Args:
        X: design matrix; ``X.shape[0]`` is used as N and ``X.shape[1]`` as P.
        Y: targets multiplied by the inverse kernel matrix (presumably length N).
        dim1, dim2: column indices of the two features being probed.
        msq, lam, eta1, xisq, c: hyperparameters used to derive ``eta2``,
            ``kappa`` and the ``kernel`` calls.
        sigma: observation noise scale added as ``sigma ** 2`` on the diagonal.

    NOTE(review): this excerpt ends without a return statement; the variance
    computation (which presumably uses ``k_prbprb``) and the return are not shown.
    """
    P, N = X.shape[1], X.shape[0]

    # ``jax.ops.index_update`` / ``jax.ops.index`` were removed from JAX; the
    # ``.at[...].set(...)`` indexed-update form is the supported equivalent.
    probe = jnp.zeros((4, P))
    probe = probe.at[:, dim1].set(jnp.array([1.0, 1.0, -1.0, -1.0]))
    probe = probe.at[:, dim2].set(jnp.array([1.0, -1.0, 1.0, -1.0]))

    eta2 = jnp.square(eta1) * jnp.sqrt(xisq) / msq
    kappa = jnp.sqrt(msq) * lam / jnp.sqrt(msq + jnp.square(eta1 * lam))

    kX = kappa * X
    kprobe = kappa * probe

    k_xx = kernel(kX, kX, eta1, eta2, c) + sigma ** 2 * jnp.eye(N)
    k_xx_inv = jnp.linalg.inv(k_xx)
    k_probeX = kernel(kprobe, kX, eta1, eta2, c)
    # Not used in the visible excerpt; presumably needed for the variance below.
    k_prbprb = kernel(kprobe, kprobe, eta1, eta2, c)

    # 0.25 * (f(+,+) - f(+,-) - f(-,+) + f(-,-)): the mixed second difference.
    vec = jnp.array([0.25, -0.25, -0.25, 0.25])
    mu = jnp.matmul(k_probeX, jnp.matmul(k_xx_inv, Y))
    mu = jnp.dot(mu, vec)
github pyro-ppl / numpyro / examples / analysis.py View on Github external
def process_singleton_pcg(dim, P, kappa, kX, omega, Y, eta1, eta2, c, rank1, rank2,
                          cg_tol=1.0e-3, max_iters=200):
    """Evaluate the single-feature contrast for column ``dim`` via ``process_probe_pcg``.

    Builds two probe rows that are +1 and -1 in column ``dim`` and zero in the
    other ``P - 1`` columns. The probes are scaled by ``kappa`` and passed to
    ``process_probe_pcg`` with the contrast vector [0.5, -0.5], i.e.
    (f(+1) - f(-1)) / 2. All other arguments, including ``cg_tol`` and
    ``max_iters``, are forwarded unchanged.

    Returns:
        The ``(mu, var)`` pair produced by ``process_probe_pcg``.
    """
    # ``jax.ops.index_update`` was removed from JAX; ``.at[...].set(...)`` is the
    # supported equivalent (assumes ``np`` is jax.numpy, as in this example file).
    probe = np.zeros((2, P)).at[:, dim].set(np.array([1.0, -1.0]))
    vec = np.array([0.50, -0.50])
    mu, var = process_probe_pcg(kappa * probe, kX, kappa, omega, Y, vec, eta1, eta2, c, rank1, rank2,
                                cg_tol=cg_tol, max_iters=max_iters)
    return mu, var
github pyro-ppl / numpyro / examples / vjp.py View on Github external
- 2.0 * meansum(probes_kX * Ainv_probes_kX) \
                                       - meansum(probes_ksqXsq * Ainv_probes_ksqXsq) \
                                       - np.mean(np.sum(probes, axis=-1) * np.sum(Ainv_probes, axis=-1)))
    log_det_ddiag = meansum(probes * diag_dot * Ainv_probes)

    tangent_out = -0.125 * (quad_form_dk + quad_form_deta1 + quad_form_deta2 + quad_form_ddiag - quad_form_db) + \
                  -0.5 * (log_det_dk + log_det_deta1 + log_det_deta2 + log_det_ddiag)
    quad_form = 0.125 * np.dot(Kb, Ainv_b)

    return (quad_form, np.mean(res_norm), np.mean(iters)), (tangent_out, 0.0, 0.0)


if __name__ == '__main__':
    N = 5
    P = 4
    b = np.array(onp.random.randn(N))
    X = np.array(onp.random.randn(N * P).reshape((N, P)))
    dkX = np.array(onp.random.randn(N * P).reshape((N, P)))
    #Ainv_b_probes = np.array(onp.random.randn(N * 2).reshape((2, N)))
    #probes = np.array(onp.random.randn(N * 1).reshape((1, N)))
    kappa = 0.3 + 2.0 * np.array(onp.random.rand(P))
    eta1 = 0.8
    eta2 = 0.5
    diag = np.array(onp.random.rand(N))
    c = 1.0
    #num_probes = 1
    #probes = np.array(onp.random.randn(N * num_probes).reshape((num_probes, N)))
    probes = math.sqrt(N) * np.eye(N)

    def direct(_kappa, _b, _eta1, _eta2, _diag, include_log_det):
        kX = _kappa * X
        k = kernel(kX, kX, _eta1, _eta2, c)