How to use the jax.jit function in jax

To help you get started, we’ve selected a few examples of jax.jit, based on popular ways it is used in public projects.
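
Before the project snippets below, here is a minimal, self-contained sketch of the two usual ways to apply jax.jit: as a decorator and as a plain function wrapper. The function names (selu, scale) are illustrative only; the first call with a given set of argument shapes and dtypes triggers compilation, and later calls with the same signature reuse the compiled code.

import jax
import jax.numpy as jnp

@jax.jit  # decorator form: compiles on the first call
def selu(x, alpha=1.67, lmbda=1.05):
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

def scale(x, factor):
    return factor * x

scale_jit = jax.jit(scale)  # wrapper form: jit(fn) returns a compiled version of fn

x = jnp.arange(5.0)
print(selu(x))            # traces and compiles, then runs
print(scale_jit(x, 2.0))  # later calls with the same shapes/dtypes hit the cache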

github JuliusKunze / jaxnet / tests / test_jaxnet.py
def test_external_submodule():
    layer = Dense(3)

    @parametrized
    def net_fun(inputs):
        return 2 * layer(inputs)

    inputs = random_inputs((2,))
    params = net_fun.init_params(PRNGKey(0), inputs)
    out = net_fun.apply(params, inputs)
    assert out.shape == (3,)

    out_ = net_fun.apply(params, inputs)
    assert np.array_equal(out, out_)

    out_ = jit(net_fun.apply)(params, inputs)
    assert np.allclose(out, out_)

github JuliusKunze / jaxnet / tests / test_jaxnet2.py
def test_external_sequential_submodule():
    layer = Sequential(Dense(2, zeros, zeros), relu)
    inputs = np.zeros((1, 2))

    params = layer.init_params(PRNGKey(0), inputs)
    assert_params_equal(((np.zeros((2, 2)), np.zeros(2)),), params)

    out = layer.apply(params, inputs)
    assert np.array_equal(np.zeros((1, 2)), out)

    out_ = jit(layer.apply)(params, inputs)
    assert np.array_equal(out, out_)

github pyro-ppl / numpyro / test / contrib / test_funsor_tve_smoke.py
def test_optimized_plated_einsum_smoke(equation, plates, backend, sizes):

    jit_raw_einsum = jax.jit(jax.value_and_grad(functools.partial(
        raw_einsum, equation=equation, plates=plates, backend=backend)))

    for i in range(2):
        operands = make_einsum_example(equation, sizes=sizes)[3]
        actual, grads = jit_raw_einsum(operands)
        assert jnp.ndim(actual) == 0
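
The numpyro test above composes three pieces: functools.partial binds the einsum options, jax.value_and_grad turns the call into a (value, gradient) computation, and jax.jit compiles the whole thing. A stripped-down sketch of the same composition on a toy loss (the names loss, params and x are made up for illustration):

import functools
import jax
import jax.numpy as jnp

def loss(params, x, scale=1.0):
    # toy quadratic loss; `scale` stands in for the keyword options bound via partial
    return scale * jnp.sum((params * x - 1.0) ** 2)

# bind the keyword options, take value and gradient w.r.t. `params`, then compile
loss_and_grad = jax.jit(jax.value_and_grad(functools.partial(loss, scale=0.5)))

params = jnp.ones(3)
x = jnp.arange(3.0)
value, grads = loss_and_grad(params, x)  # scalar value, grads with the shape of params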

github google / jax / examples / kernel_lsq.py
def train(kernel, xs, ys, regularization=0.01):
  gram_ = jit(partial(gram, kernel))
  gram_mat = gram_(xs)
  n = xs.shape[0]

  def objective(v):
    risk = .5 * np.sum((np.dot(gram_mat, v) - ys) ** 2.0)
    reg = regularization * np.sum(v ** 2.0)
    return risk + reg

  v = minimize(objective, np.zeros(n))

  def predict(x):
    prods = vmap(lambda x_: kernel(x, x_))(xs)
    return np.sum(v * prods)

  return jit(vmap(predict))
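
train above returns jit(vmap(predict)): the single-example predictor is first vectorized over its input and then compiled as one unit. A minimal sketch of the same jit-over-vmap pattern with a toy kernel (rbf, xs and v are made-up stand-ins for the fitted model):

import jax.numpy as jnp
from jax import jit, vmap

def rbf(x, y):
    # toy RBF kernel between two vectors
    return jnp.exp(-jnp.sum((x - y) ** 2))

xs = jnp.ones((10, 2))  # "training" points
v = jnp.ones(10)        # fitted coefficients

def predict(x):
    prods = vmap(lambda x_: rbf(x, x_))(xs)  # kernel of x against every training point
    return jnp.sum(v * prods)

batched_predict = jit(vmap(predict))     # vectorize over a batch of inputs, then compile
ys = batched_predict(jnp.zeros((5, 2)))  # shape (5,)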

github pyro-ppl / numpyro / numpyro / util.py
    collection = jnp.zeros((collection_size,) + init_val_flat.shape)
    if not progbar:
        last_val, collection, _ = fori_loop(0, upper, _body_fn, (init_val, collection, lower))
    else:
        diagnostics_fn = progbar_opts.pop('diagnostics_fn', None)
        progbar_desc = progbar_opts.pop('progbar_desc', lambda x: '')

        vals = (init_val, collection, device_put(lower))
        if upper == 0:
            # special case, only compiling
            jit(_body_fn)(0, vals)
        else:
            with tqdm.trange(upper) as t:
                for i in t:
                    vals = jit(_body_fn)(i, vals)
                    t.set_description(progbar_desc(i), refresh=False)
                    if diagnostics_fn:
                        t.set_postfix_str(diagnostics_fn(vals[0]), refresh=False)

        last_val, collection, _ = vals

    unravel_collection = vmap(unravel_fn)(collection)
    return (unravel_collection, last_val) if return_last_val else unravel_collection
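
The numpyro helper above calls the jitted body once in the upper == 0 branch purely to trigger compilation, and then reuses it inside a Python loop driving the progress bar. A hedged sketch of the same warm-up idea, with the jitted wrapper bound once and a toy body function (body_fn and the carried value are illustrative):

import jax
import jax.numpy as jnp

def body_fn(i, val):
    # toy loop body: fold the index into a running value
    return val + i

body = jax.jit(body_fn)  # bind the compiled wrapper once and reuse it
val = jnp.float32(0.0)

body(0, val)             # warm-up call: traces and compiles for these shapes/dtypes
for i in range(100):     # subsequent calls with the same signature reuse the compiled code
    val = body(i, val)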

github google / jax / jax / experimental / ode.py
@jax.jit
def swoop(y, t, arg1, arg2):
  return np.array(y - np.sin(t) - np.cos(t) * arg1 + arg2)
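
swoop above is jitted with every argument treated as an array. When an argument must remain a concrete Python value inside the function (for example, one that controls an output shape or a Python-level branch), jax.jit has to be told via static_argnums. A small illustrative sketch (make_ramp, scale and size are made-up names):

from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=1)  # `size` is static: it determines the output shape
def make_ramp(scale, size):
    return scale * jnp.arange(size)

print(make_ramp(2.0, 5))  # compiles for size=5
print(make_ramp(3.0, 5))  # reuses the size=5 compilation
print(make_ramp(2.0, 8))  # size changed, so this call recompiles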

github google / trax / trax / rl / ppo.py
@jit
def ppo_loss_given_predictions(log_probab_actions_new,
                               log_probab_actions_old,
                               value_predictions_old,
                               padded_actions,
                               rewards_to_actions,
                               padded_rewards,
                               reward_mask,
                               gamma,
                               lambda_,
                               epsilon):
  """PPO objective, with an eventual minus sign, given predictions."""
  B, RT = padded_rewards.shape  # pylint: disable=invalid-name
  _, AT, A = log_probab_actions_old.shape  # pylint: disable=invalid-name

  assert (B, RT) == padded_rewards.shape
  assert (B, AT) == padded_actions.shape

github tensorflow / tensor2tensor / tensor2tensor / trax / rl / ppo.py
@jit
def ppo_loss_given_predictions(log_probab_actions_new,
                               log_probab_actions_old,
                               value_predictions_old,
                               padded_actions,
                               rewards_to_actions,
                               padded_rewards,
                               reward_mask,
                               gamma=0.99,
                               lambda_=0.95,
                               epsilon=0.2):
  """PPO objective, with an eventual minus sign, given predictions."""
  B, RT = padded_rewards.shape  # pylint: disable=invalid-name
  _, AT, A = log_probab_actions_old.shape  # pylint: disable=invalid-name

  assert (B, RT) == padded_rewards.shape
  assert (B, AT) == padded_actions.shape

github pyro-ppl / numpyro / examples / vae.py
    @jit
    def epoch_train(svi_state, rng_key):
        def body_fn(i, val):
            loss_sum, svi_state = val
            rng_key_binarize = random.fold_in(rng_key, i)
            batch = binarize(rng_key_binarize, train_fetch(i, train_idx)[0])
            svi_state, loss = svi.update(svi_state, batch)
            loss_sum += loss
            return loss_sum, svi_state

        return lax.fori_loop(0, num_train, body_fn, (0., svi_state))
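
epoch_train above puts a whole training epoch, a lax.fori_loop over minibatches, under a single @jit, so the loop itself is compiled rather than stepped through in Python. A minimal sketch of the same jit-around-fori_loop structure with a trivial body (run_epoch and its "loss" are placeholders, not numpyro's svi.update):

import jax.numpy as jnp
from jax import jit, lax

@jit
def run_epoch(state):
    def body_fn(i, val):
        loss_sum, state = val
        loss = jnp.mean(state ** 2)          # placeholder "loss" for illustration
        return loss_sum + loss, state * 0.9  # accumulate loss, update the state
    return lax.fori_loop(0, 10, body_fn, (jnp.float32(0.0), state))

loss_sum, new_state = run_epoch(jnp.ones(3))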

github google / jax / jax / experimental / ode.py
  @jax.jit
  def _fori_body_fun(i, val):
    """fori_loop function for VJP calculation."""
    rev_yt, rev_t, rev_tarray, rev_gi, vjp_y, vjp_t0, vjp_args, time_vjp_list = val
    this_yt = rev_yt[i, :]
    this_t = rev_t[i]
    this_tarray = rev_tarray[i, :]
    this_gi = rev_gi[i, :]
    # this is g[i-1, :] when g has been reversed
    this_gim1 = rev_gi[i+1, :]
    state_len = this_yt.shape[0]
    vjp_cur_t = np.dot(flat_func(this_yt, this_t, flat_args), this_gi)
    vjp_t0 = vjp_t0 - vjp_cur_t
    # Run augmented system backwards to the previous observation.
    aug_y0 = np.hstack((this_yt, vjp_y, vjp_t0, vjp_args))
    aug_ans = odeint(rev_aug_dynamics,
                     aug_y0,