How to use the jax.numpy.dot function in jax

To help you get started, we’ve selected a few examples of jax.numpy.dot based on popular ways it is used in public projects.

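Before the project snippets, here is a minimal, self-contained sketch of what jax.numpy.dot computes: an inner product for two 1-D arrays, a matrix-vector product for a 2-D and a 1-D array, and a matrix product for two 2-D arrays. It also composes with jax transformations such as grad and jit. Note that the project snippets below use the older `import jax.numpy as np` convention; this sketch uses the `jnp` alias, and the array values are made up.

import jax
import jax.numpy as jnp

a = jnp.array([1.0, 2.0, 3.0])
b = jnp.array([4.0, 5.0, 6.0])
M = jnp.arange(9.0).reshape(3, 3)

print(jnp.dot(a, b))    # inner product of two 1-D arrays -> 32.0
print(jnp.dot(M, a))    # matrix-vector product -> shape (3,)
print(jnp.dot(M, M))    # matrix-matrix product -> shape (3, 3)

# dot also works under jax transformations, e.g. differentiating a quadratic form.
quad = lambda x: jnp.dot(x, jnp.dot(M, x))
print(jax.grad(quad)(a))    # equals (M + M.T) @ a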

github pyro-ppl / numpyro / examples / mvm.py View on Github
    def compute_element(i):
        # Each output element is the dot product of rhs with one (lazily built) row.
        return np.dot(rhs, row(i))
    return _chunk_vmap(compute_element, np.arange(rhs.shape[-1]), rhs.shape[-1] // dilation)
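
In this fragment, `rhs`, `row`, `dilation`, and `_chunk_vmap` are defined elsewhere in numpyro's mvm.py; the pattern is to build a matrix-vector product one output element at a time, vmapping the per-element dot over chunks of indices to bound memory. A standalone sketch of that pattern (the names and chunking scheme here are illustrative, not numpyro's API):

import jax
import jax.numpy as jnp

def chunked_matvec(A, x, num_chunks=4):
    """Compute A @ x by vmapping a per-row dot product over chunks of row indices."""
    def one_element(i):
        return jnp.dot(A[i], x)          # dot of one row with the vector
    chunks = jnp.split(jnp.arange(A.shape[0]), num_chunks)
    return jnp.concatenate([jax.vmap(one_element)(idx) for idx in chunks])

A = jnp.arange(16.0).reshape(8, 2)
x = jnp.array([1.0, 2.0])
assert jnp.allclose(chunked_matvec(A, x), A @ x)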
github probml / pyprobml / scripts / autodiff_demo_jax.py View on Github
# Summing the per-example gradients by dotting a row of ones against them.
grad_sum2 = np.dot(np.ones((1, N)), grads)
assert np.allclose(grad_sum, grad_sum2)

# Now make things go fast
from jax import jit

grad_fun = jit(grad(loss))
grads = vmap(partial(grad_fun, w))(X, y)
assert np.allclose(grads, grads2)


# Logistic regression: compare the autodiff Hessian of the loss against the
# closed form X.T @ S @ X with S = diag(mu * (1 - mu)).
H1 = hessian(loss)(w, X, y)
mu = predict(w, X)
S = np.diag(mu * (1 - mu))
H2 = np.dot(np.dot(X.T, S), X)
assert np.allclose(H1, H2)
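
The fragment above relies on `loss`, `predict`, `hessian`, `w`, `X`, `y`, and `N` defined earlier in the script. A self-contained version of the same Hessian check, assuming a standard binary logistic-regression negative log-likelihood (all names and data below are illustrative):

import jax
import jax.numpy as jnp

def predict(w, X):
    return jax.nn.sigmoid(jnp.dot(X, w))          # per-example probabilities

def loss(w, X, y):
    mu = predict(w, X)
    return -jnp.sum(y * jnp.log(mu) + (1 - y) * jnp.log(1 - mu))

key = jax.random.PRNGKey(0)
X = jax.random.normal(key, (20, 3))
w = jnp.array([0.5, -1.0, 2.0])
y = (predict(w, X) > 0.5).astype(jnp.float32)

H1 = jax.hessian(loss)(w, X, y)                   # autodiff Hessian w.r.t. w
mu = predict(w, X)
S = jnp.diag(mu * (1 - mu))
H2 = jnp.dot(jnp.dot(X.T, S), X)                  # closed form X^T S X
assert jnp.allclose(H1, H2, atol=1e-4)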
github pyro-ppl / numpyro / examples / cg.py View on Github
def pcg_quad_form_log_det_jvp(primals, tangents):
    A, b, probes, cg_tol, max_iters = primals
    A_dot, b_dot, _, _, _ = tangents  # only the tangents of A and b are used
    D = b.shape[-1]

    # Solve A x = b and A x = probe for all probes in one batched PCG call.
    b_probes = np.concatenate([b[None, :], probes])
    Ainv_b_probes, res_norm, iters = pcg_batch_b(b_probes, A, max_iters=max_iters)
    Ainv_b, Ainv_probes = Ainv_b_probes[0], Ainv_b_probes[1:]

    # Tangent of the quadratic form b^T A^{-1} b with respect to A and b, plus a
    # stochastic, probe-based estimate of the tangent of log det A.
    quad_form_dA = -np.dot(Ainv_b, np.matmul(A_dot, Ainv_b))
    quad_form_db = 2.0 * np.dot(Ainv_b, b_dot)
    log_det_dA = np.mean(np.einsum('...i,...i->...', np.matmul(probes, A_dot), Ainv_probes))
    tangent_out = log_det_dA + quad_form_dA + quad_form_db
    quad_form = np.dot(b, Ainv_b)

    return quad_form, tangent_out
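
`pcg_batch_b` and the custom_jvp plumbing around this function live elsewhere in cg.py. As a sanity check on the tangent formulas above, this sketch differentiates the exact quadratic form b^T A^-1 b with jax.jvp and compares it to the hand-derived expression (matrix size and values are made up):

import jax
import jax.numpy as jnp

def quad_form(A, b):
    return jnp.dot(b, jnp.linalg.solve(A, b))      # b^T A^{-1} b

key = jax.random.PRNGKey(0)
k1, k2, k3, k4 = jax.random.split(key, 4)
M = jax.random.normal(k1, (4, 4))
A = M @ M.T + 4.0 * jnp.eye(4)                     # symmetric positive definite
b = jax.random.normal(k2, (4,))
A_dot = jax.random.normal(k3, (4, 4))
b_dot = jax.random.normal(k4, (4,))

# Automatic tangent via jax.jvp.
_, auto_tangent = jax.jvp(quad_form, (A, b), (A_dot, b_dot))

# Hand-derived tangent: -x^T A_dot x + 2 x^T b_dot, with x = A^{-1} b.
x = jnp.linalg.solve(A, b)
manual_tangent = -jnp.dot(x, A_dot @ x) + 2.0 * jnp.dot(x, b_dot)
print(auto_tangent, manual_tangent)                # should agree closely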
github google / jax / jax / experimental / odeint.py View on Github
Args:
      y0: function value at the start of the interval.
      y1: function value at the end of the interval.
      y_mid: function value at the mid-point of the interval.
      dy0: derivative value at the start of the interval.
      dy1: derivative value at the end of the interval.
      dt: width of the interval.
  Returns:
      Coefficients `[a, b, c, d, e]` for the polynomial
      p = a * x ** 4 + b * x ** 3 + c * x ** 2 + d * x + e
  """
  v = np.stack([dy0, dy1, y0, y1, y_mid])
  a = np.dot(np.hstack([-2. * dt, 2. * dt, np.array([-8., -8., 16.])]), v)
  b = np.dot(np.hstack([5. * dt, -3. * dt, np.array([18., 14., -32.])]), v)
  c = np.dot(np.hstack([-4. * dt, dt, np.array([-11., -5., 16.])]), v)
  d = dt * dy0
  e = y0
  return a, b, c, d, e
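
The def line of the surrounding function is cut off above; it fits a quartic through two endpoint values, a midpoint value, and two endpoint derivatives, and each np.dot forms one coefficient as a fixed linear combination of the stacked inputs `v`. A quick way to see what the coefficients mean is to evaluate the polynomial, here with made-up scalar inputs (the snippet's jnp.stack also handles vector-valued states):

import jax.numpy as jnp

# Illustrative values: interpolate over an interval of width dt given endpoint
# values y0, y1, the midpoint value y_mid, and endpoint derivatives dy0, dy1.
y0, y1, y_mid, dy0, dy1, dt = 1.0, 2.0, 1.8, 0.5, -0.3, 0.1

v = jnp.array([dy0, dy1, y0, y1, y_mid])
a = jnp.dot(jnp.hstack([-2. * dt, 2. * dt, jnp.array([-8., -8., 16.])]), v)
b = jnp.dot(jnp.hstack([5. * dt, -3. * dt, jnp.array([18., 14., -32.])]), v)
c = jnp.dot(jnp.hstack([-4. * dt, dt, jnp.array([-11., -5., 16.])]), v)
d = dt * dy0
e = y0

# Evaluating p at x = 0, 0.5, 1 recovers y0, y_mid, y1, i.e. the coefficients
# parameterize the interval on a normalized x in [0, 1].
coeffs = jnp.array([a, b, c, d, e])
print(jnp.polyval(coeffs, jnp.array([0.0, 0.5, 1.0])))   # approximately [1.0 1.8 2.0]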
github google / jax / mask.py View on Github
def matvec(A, b):
  return np.dot(A, b)
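
This two-line wrapper works because np.dot dispatches on rank: with a 2-D `A` and a 1-D `b` it is a matrix-vector product. Naming it also makes it convenient to transform, for instance batching it with vmap (the shapes below are made up):

import jax
import jax.numpy as jnp

def matvec(A, b):
    return jnp.dot(A, b)

A = jnp.ones((4, 3))
b = jnp.arange(3.0)
print(matvec(A, b).shape)                # (4,)

# Batch over a stack of matrices and a stack of vectors.
As = jnp.ones((8, 4, 3))
bs = jnp.ones((8, 3))
print(jax.vmap(matvec)(As, bs).shape)    # (8, 4)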
github google / trax / trax / rl / ppo_trainer.py View on Github
nontrainable_params=self._nontrainable_params,
              state=self._model_state,
              rng=k1))
      opt_step += 1
      self._total_opt_step += 1

      # Compute the approximate KL for early stopping. Use the whole dataset;
      # since we only do inference here, it should fit in memory.
      (log_probab_actions_new, _) = (
          self._policy_and_value_net_apply(
              padded_observations,
              weights=self._policy_and_value_net_weights,
              state=self._model_state,
              rng=k2))

      action_mask = np.dot(
          np.pad(reward_mask, ((0, 0), (0, 1))), self._rewards_to_actions
      )
      approx_kl = ppo.approximate_kl(log_probab_actions_new, log_probabs_traj,
                                     action_mask)

      early_stopping = approx_kl > 1.5 * self._target_kl
      if early_stopping:
        logging.vlog(
            1, 'Early stopping policy and value optimization after %d steps, '
            'with approx_kl: %0.2f', opt_step, approx_kl)
        # We don't return right away; we want the code below to execute on the
        # last iteration.

      t2 = time.time()
      if (opt_step % self._print_every_optimizer_steps == 0 or
          opt_step == self._n_optimizer_steps or early_stopping):
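
Here np.dot is doing index bookkeeping rather than heavy linear algebra: `reward_mask` has one entry per reward time step, `rewards_to_actions` maps the (padded) reward time steps onto action time steps, and the product scatters the mask onto the action axis. A toy version of that mapping (the shapes and the 0/1 matrix are illustrative):

import jax.numpy as jnp

# Toy shapes: B=2 trajectories, RT=3 reward time steps, AT=4 action time steps.
reward_mask = jnp.array([[1, 1, 0],
                         [1, 1, 1]], dtype=jnp.float32)

# Maps each of the RT+1 (padded) reward steps to the action steps it covers.
rewards_to_actions = jnp.array([[1, 1, 0, 0],
                                [0, 0, 1, 0],
                                [0, 0, 0, 1],
                                [0, 0, 0, 0]], dtype=jnp.float32)

padded_reward_mask = jnp.pad(reward_mask, ((0, 0), (0, 1)))    # (B, RT+1)
action_mask = jnp.dot(padded_reward_mask, rewards_to_actions)  # (B, AT)
print(action_mask)
# [[1. 1. 1. 0.]
#  [1. 1. 1. 1.]]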
github tensorflow / tensor2tensor / tensor2tensor / trax / rl / ppo.py View on Github
value_prediction_old=value_prediction_old,
      epsilon=epsilon)
  (ppo_loss, ppo_summaries) = ppo_loss_given_predictions(
      log_probab_actions_new,
      log_probab_actions_old,
      value_prediction_old,
      padded_actions,
      rewards_to_actions,
      padded_rewards,
      reward_mask,
      gamma=gamma,
      lambda_=lambda_,
      epsilon=epsilon)
  # Pad the reward mask to be compatible with rewards_to_actions.
  padded_reward_mask = np.pad(reward_mask, ((0, 0), (0, 1)))
  action_mask = np.dot(padded_reward_mask, rewards_to_actions)
  entropy_bonus = masked_entropy(log_probab_actions_new, action_mask)
  # Total loss: clipped policy loss plus the weighted value loss, minus the
  # weighted entropy bonus.
  combined_loss_ = ppo_loss + (c1 * value_loss) - (c2 * entropy_bonus)

  summaries = {
      "combined_loss": combined_loss_,
      "entropy_bonus": entropy_bonus,
  }
  for loss_summaries in (value_summaries, ppo_summaries):
    summaries.update(loss_summaries)

  return (combined_loss_, (ppo_loss, value_loss, entropy_bonus), summaries)
github pyro-ppl / numpyro / examples / cg.py View on Github
def cg_body_fun(state, mvm):
    x, r, p, r_dot_r, iteration = state
    Ap = mvm(p)                          # one matrix-vector product per iteration
    alpha = r_dot_r / np.dot(p, Ap)      # step size along the current search direction
    x = x + alpha * p                    # update the solution estimate
    r = r - alpha * Ap                   # update the residual
    beta_denom = r_dot_r
    r_dot_r = np.dot(r, r)
    beta = r_dot_r / beta_denom          # ratio of new to old squared residual norms
    p = r + beta * p                     # next search direction
    return CGState(x, r, p, r_dot_r, iteration + 1)
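
`CGState` and the loop that drives cg_body_fun live elsewhere in cg.py. A self-contained sketch of the same conjugate-gradient iteration, run for a fixed number of steps with jax.lax.fori_loop on a small symmetric positive definite system (the names and the fixed iteration count are simplifications):

import jax
import jax.numpy as jnp

def conjugate_gradient(A, b, num_iters=10):
    """Solve A x = b for symmetric positive definite A."""
    def body(_, state):
        x, r, p, r_dot_r = state
        Ap = A @ p
        alpha = r_dot_r / jnp.dot(p, Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        new_r_dot_r = jnp.dot(r, r)
        p = r + (new_r_dot_r / r_dot_r) * p
        return x, r, p, new_r_dot_r

    x0 = jnp.zeros_like(b)
    state = (x0, b, b, jnp.dot(b, b))          # residual and direction both start at b
    x, *_ = jax.lax.fori_loop(0, num_iters, body, state)
    return x

key = jax.random.PRNGKey(0)
M = jax.random.normal(key, (5, 5))
A = M @ M.T + 5.0 * jnp.eye(5)
b = jnp.ones(5)
x = conjugate_gradient(A, b)
print(jnp.max(jnp.abs(A @ x - b)))             # should be close to zero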
github google / trax / trax / rl / ppo.py View on Github
gamma=gamma)

  # (B, RT)
  advantages = gae_advantages(
      td_deltas, reward_mask, lambda_=lambda_, gamma=gamma)

  # Normalize the advantages.
  advantage_mean = np.mean(advantages)
  advantage_std = np.std(advantages)
  advantages = (advantages - advantage_mean) / (advantage_std + 1e-8)

  # Scatter advantages over padded_actions.
  # rewards_to_actions is RT + 1 -> AT, so we pad the advantages and the reward
  # mask by 1.
  advantages = np.dot(np.pad(advantages, ((0, 0), (0, 1))), rewards_to_actions)
  action_mask = np.dot(
      np.pad(reward_mask, ((0, 0), (0, 1))), rewards_to_actions
  )

  # (B, AT)
  ratios = compute_probab_ratios(log_probab_actions_new, log_probab_actions_old,
                                 padded_actions, action_mask)
  assert (B, AT) == ratios.shape

  # (B, AT)
  objective = clipped_objective(
      ratios, advantages, action_mask, epsilon=epsilon)
  assert (B, AT) == objective.shape

  # ()
  average_objective = np.sum(objective) / np.sum(action_mask)
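
`compute_probab_ratios` and `clipped_objective` are defined elsewhere in ppo.py; this toy sketch reproduces the standard clipped PPO surrogate and the mask-weighted average from the last line above, with made-up numbers:

import jax.numpy as jnp

# Toy per-action quantities for B=1 trajectory, AT=4 action steps.
ratios = jnp.array([[1.3, 0.7, 1.05, 2.0]])       # new/old probability ratios
advantages = jnp.array([[1.0, -1.0, 0.5, 1.0]])
action_mask = jnp.array([[1.0, 1.0, 1.0, 0.0]])   # last step is padding
epsilon = 0.2

# Clipped surrogate, then a mask-weighted average over valid action steps.
clipped = jnp.clip(ratios, 1.0 - epsilon, 1.0 + epsilon)
objective = jnp.minimum(ratios * advantages, clipped * advantages) * action_mask
average_objective = jnp.sum(objective) / jnp.sum(action_mask)
print(average_objective)                          # approximately 0.308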