How to use the jax.numpy.arange function in JAX

To help you get started, we've selected a few JAX examples based on popular ways jax.numpy.arange is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes (no build needed) and fix issues immediately.

GitHub: pyro-ppl/numpyro — examples/mvm.py (view on GitHub)
# Excerpt: the enclosing function, which provides `row` and `dilation` and
# returns `do_mvm` (see the final line), is not shown here.
def do_mvm(rhs):
        """Evaluate ``np.dot(rhs, row(i))`` for every ``i`` in ``0 .. rhs.shape[-1] - 1``.

        NOTE(review): presumably this is a matrix-vector product with an implicit
        matrix whose i-th row is ``row(i)``; confirm against the enclosing code.
        """
        @jit
        def compute_element(i):
            return np.dot(rhs, row(i))
        # NOTE(review): the third argument looks like a chunk size for
        # _chunk_vmap (rhs.shape[-1] // dilation); it is 0 whenever
        # dilation > rhs.shape[-1] — confirm _chunk_vmap handles that.
        return _chunk_vmap(compute_element, np.arange(rhs.shape[-1]), rhs.shape[-1] // dilation)
    # Belongs to the unseen enclosing function: hands the closure back.
    return do_mvm
GitHub: pyro-ppl/numpyro — test/contrib/test_control_flow.py (view on GitHub)
def model(T=10, q=1, r=1, phi=0., beta=0.):
    """State-space model whose time steps are unrolled with ``scan``.

    For each element of ``jnp.arange(T)`` the transition draws::

        x_t  ~ Normal(phi * x_{t-1}, q)
        mu_t = beta * mu_{t-1} + x_t
        y_t  ~ Normal(mu_t, r)

    and records the deterministic site ``'y2'`` equal to ``y_t * 2``.
    The initial state is ``x_0 ~ Normal(0, q)``, ``mu_0 = x_0`` and
    ``y_0 ~ Normal(mu_0, r)``.

    Returns:
        ``(x, y)``: ``x_0`` / ``y_0`` prepended (via ``jnp.append``) to the
        per-step ``x`` / ``y`` values returned by ``scan``.
    """
    def transition(carry, _t):
        x_prev, mu_prev = carry
        x_t = numpyro.sample('x', dist.Normal(phi * x_prev, q))
        mu_t = beta * mu_prev + x_t
        y_t = numpyro.sample('y', dist.Normal(mu_t, r))
        numpyro.deterministic('y2', y_t * 2)
        # First element is the new carry; second is this step's output.
        return (x_t, mu_t), (x_t, y_t)

    x0 = numpyro.sample('x_0', dist.Normal(0, q))
    mu0 = x0
    y0 = numpyro.sample('y_0', dist.Normal(mu0, r))

    _, (x_path, y_path) = scan(transition, (x0, mu0), jnp.arange(T))
    return jnp.append(x0, x_path), jnp.append(y0, y_path)
GitHub: pyro-ppl/numpyro — numpyro/contrib/nn/auto_reg_nn.py (view on GitHub)
# Excerpt from the middle of a function body: the signature is not shown, so
# names such as param_dims, input_dim, hidden_dims, permutation, nonlinearity
# and skip_connections come from that unseen scope. The excerpt also stops
# mid-expression at the end.
# Passed to create_mask below as output_dim_multiplier.
output_multiplier = sum(param_dims)
    # True when every entry of param_dims equals 1 (not used in the visible part).
    all_ones = (np.array(param_dims) == 1).all()

    # Calculate the indices on the output corresponding to each parameter:
    # consecutive [start, end) ranges laid out in param_dims order.
    ends = np.cumsum(np.array(param_dims), axis=0)
    # np.zeros(1) is float, so starts is float; hence the int() casts below.
    starts = np.concatenate((np.zeros(1), ends[:-1]))
    param_slices = [slice(int(s), int(e)) for s, e in zip(starts, ends)]

    # A hidden dimension must not be less than the input dimension; otherwise
    # it isn't possible to connect to the outputs correctly.
    for h in hidden_dims:
        if h < input_dim:
            raise ValueError('Hidden dimension must not be less than input dimension.')

    # Default ordering is the identity 0 .. input_dim - 1.
    if permutation is None:
        permutation = np.arange(input_dim)

    # Create masks
    # NOTE(review): mask_skip is not used in the visible part; presumably it is
    # consumed by the skip-connection branch that is cut off below.
    masks, mask_skip = create_mask(input_dim=input_dim, hidden_dims=hidden_dims,
                                   permutation=permutation,
                                   output_dim_multiplier=output_multiplier)

    main_layers = []
    # Create masked layers, inserting `nonlinearity` between consecutive masked
    # layers but not after the last one.
    for i, mask in enumerate(masks):
        main_layers.append(MaskedDense(mask))
        if i < len(masks) - 1:
            main_layers.append(nonlinearity)

    if skip_connections:
        net_init, net = stax.serial(stax.FanOut(2),
                                    stax.parallel(stax.serial(*main_layers),
GitHub: pyro-ppl/numpyro — examples/ode.py (view on GitHub)
def model(N, y=None):
    """Hare/lynx population model driven by an ODE and observed on the log scale.

    :param int N: number of measurement times
    :param numpy.ndarray y: measured populations with shape (N, 2)
    """
    # Initial sizes of the two populations.
    z_init = numpyro.sample("z_init", dist.LogNormal(jnp.log(10), 1), sample_shape=(2,))
    # Measurement times 0., 1., ..., N - 1 (float).
    times = jnp.arange(float(N))
    # Prior for alpha, beta, gamma, delta of dz_dt, truncated below at 0.
    theta_loc = jnp.array([0.5, 0.05, 1.5, 0.05])
    theta_scale = jnp.array([0.5, 0.05, 0.5, 0.05])
    theta = numpyro.sample(
        "theta",
        dist.TruncatedNormal(low=0., loc=theta_loc, scale=theta_scale))
    # Integrate dz/dt; the trajectory has shape N x 2.
    trajectory = odeint(dz_dt, z_init, times, theta, rtol=1e-5, atol=1e-3, mxstep=500)
    # Measurement noise: the hare series is expected to be noisier than lynx.
    sigma = numpyro.sample("sigma", dist.Exponential(jnp.array([1, 2])))
    # Observed populations are compared on the log scale.
    log_trajectory = jnp.log(trajectory)
    numpyro.sample("y", dist.Normal(log_trajectory, sigma), obs=y)
GitHub: pyro-ppl/numpyro — examples/mvm.py (view on GitHub)
# Excerpt: the enclosing function, which provides `row` and `N` and returns
# `do_mvm` (see the final line), is not shown here.
def do_mvm(rhs):
        """Stack ``row(0), ..., row(N - 1)`` into ``M`` and return ``np.matmul(M, rhs)``.

        Unlike a chunked per-element evaluation, this materializes the whole
        matrix ``M`` at once.
        """
        M = vmap(row)(np.arange(N))
        return np.matmul(M, rhs)
    # Belongs to the unseen enclosing function: hands the closure back.
    return do_mvm
GitHub: pyro-ppl/numpyro — examples/analysis.py (view on GitHub)
def fun(omega):
    """Evaluate ``process_singleton_pcg`` for every ``dim`` in ``0 .. P - 1``.

    The per-dimension calls are mapped with ``chunk_vmap`` over
    ``np.arange(P)``, passing ``chunk_size=probe_chunk_size``. Everything except
    ``omega`` and ``dim`` (P, kappa, kX, Y, eta1, eta2, c, rank1, rank2, cg_tol,
    max_iters, probe_chunk_size) is a free variable from a scope not shown here.
    """
    # A named inner function instead of a lambda bound to a name (PEP 8, E731).
    def _fun(dim):
        return process_singleton_pcg(dim, P, kappa, kX, omega, Y, eta1, eta2, c, rank1, rank2,
                                     cg_tol=cg_tol, max_iters=max_iters)

    return chunk_vmap(_fun, np.arange(P), chunk_size=probe_chunk_size)
GitHub: tensorflow/tensor2tensor — tensor2tensor/trax/rl/ppo.py (view on GitHub)
def chosen_probabs(probab_actions, actions):
  """Picks out the probabilities of the actions along batch and time-steps.

  Args:
    probab_actions: ndarray of shape `[B, AT, A]`, where
      probab_actions[b, t, i] contains the log-probability of action = i at
      the t^th time-step in the b^th trajectory.
    actions: ndarray of shape `[B, AT]`, with each entry in [0, A) denoting
      which action was chosen in the b^th trajectory's t^th time-step.

  Returns:
    `[B, AT]` ndarray whose entry [b, t] is
    probab_actions[b, t, actions[b, t]], i.e. the log-probability of the
    chosen action.

  Raises:
    ValueError: if the leading two dimensions of `probab_actions` differ from
      the shape of `actions`.
  """
  B, AT = actions.shape  # pylint: disable=invalid-name
  # An explicit check instead of `assert`, so it is not stripped under -O.
  if tuple(probab_actions.shape[:2]) != (B, AT):
    raise ValueError(
        "probab_actions.shape[:2] {} does not match actions.shape {}".format(
            tuple(probab_actions.shape[:2]), (B, AT)))
  # Advanced indexing: the batch index (B, 1), the time index (AT,) and
  # `actions` (B, AT) broadcast together to (B, AT), so entry [b, t] reads
  # probab_actions[b, t, actions[b, t]].
  return probab_actions[np.arange(B)[:, None], np.arange(AT), actions]
GitHub: sharadmv/deepx — deepx/backend/jax.py (view on GitHub)
def range(self, start, limit=None, delta=1):
    """Return ``np.arange(start, limit, step=delta)``.

    When ``limit`` is None, ``start`` acts as the stop value and the sequence
    begins at 0, following ``arange`` semantics. ``self`` is not used.
    """
    stop = limit
    step_size = delta
    return np.arange(start, stop, step=step_size)
GitHub: pyro-ppl/numpyro — examples/cg.py (view on GitHub)
def lowrank_presolve(kX, D, eta1, eta2, c, kappa, rank1, rank2):
    """Build a solver for ``(diag(D) + Z Z^T) x = b`` using the Woodbury identity.

    ``Z`` is assembled from the columns of ``kX`` with the largest ``kappa``
    values (``argsort`` is ascending, so those come last):

    * if ``rank2 > 0``: the pairwise products ``x_i * x_j`` (i > j) of the
      top-``rank2`` columns, scaled by ``eta2``;
    * the top-``rank1`` columns, scaled by ``eta1``;
    * a constant column equal to ``c``.

    Returns:
        A callable ``b -> b / D - (Z / D) (I + Z^T (Z / D))^{-1} (Z / D)^T b``.
        For a length-N vector ``b`` this equals ``(diag(D) + Z Z^T)^{-1} b``.
        Assumes D has shape (N,) with positive entries, so the Cholesky factor
        exists (TODO confirm with callers).
    """
    n_obs, n_feat = kX.shape
    by_kappa = np.argsort(kappa)

    top1 = dynamic_slice_in_dim(by_kappa, n_feat - rank1, rank1)
    cols1 = np.take(kX, top1, -1)  # N x rank1
    const_col = c * np.ones((n_obs, 1))

    if rank2 <= 0:
        Z = np.concatenate([eta1 * cols1, const_col], axis=1)
    else:
        top2 = dynamic_slice_in_dim(by_kappa, n_feat - rank2, rank2)
        cols2 = np.take(kX, top2, -1)  # N x rank2
        outer = cols2[:, None, :] * cols2[:, :, None]  # N x rank2 x rank2
        # Keep each unordered pair once: entries strictly below the diagonal.
        strictly_lower = np.ravel(np.arange(rank2) < np.arange(rank2)[:, None])
        pairs = np.compress(strictly_lower, outer.reshape((n_obs, -1)), axis=-1)
        Z = np.concatenate([eta2 * pairs, eta1 * cols1, const_col], axis=1)

    z_scaled = Z / D[:, None]  # D^{-1} Z
    # Small (rank x rank) capacitance matrix I + Z^T D^{-1} Z.
    capacitance = np.eye(z_scaled.shape[-1]) + Z.T @ z_scaled
    chol = cho_factor(capacitance, lower=True)[0]

    def presolve(b):
        return b / D - z_scaled @ cho_solve((chol, True), z_scaled.T @ b)

    return presolve