How to use the trax.backend.numpy.zeros_like function in trax

To help you get started, we've selected a few trax examples based on popular ways zeros_like is used in public projects.

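trax.backend.numpy mirrors the NumPy API on top of the active backend (JAX in the snippets below), so np.zeros_like(x) returns an array of zeros with the same shape and dtype as x. Here is a minimal sketch, assuming a trax version where the backend module still lives at trax.backend, as in the examples on this page (newer releases moved it):

from trax.backend import numpy as np

x = np.arange(6, dtype=np.float32).reshape((2, 3))
zeros = np.zeros_like(x)  # same shape (2, 3) and dtype float32, filled with zeros
print(zeros.shape, zeros.dtype)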

github google / trax / trax / optimizers / adam.py
def init(self, params):
    # Adam keeps two accumulators per parameter tensor: the first moment m and
    # the second moment v, both starting at zero with the parameters' shape and dtype.
    m = np.zeros_like(params)
    v = np.zeros_like(params)
    return m, v
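These zero-initialized slots are what the optimizer updates alongside the weights on every step. As a rough, self-contained illustration in plain NumPy (not trax's actual Adam.update; adam_step and its hyperparameter defaults below are placeholders):

import numpy as onp

def adam_step(step, grads, params, m, v,
              learning_rate=1e-3, b1=0.9, b2=0.999, eps=1e-8):
  # Update the first- and second-moment estimates.
  m = (1 - b1) * grads + b1 * m
  v = (1 - b2) * (grads ** 2) + b2 * v
  # Bias-correct them (they start at zero, so early estimates are biased low).
  mhat = m / (1 - b1 ** (step + 1))
  vhat = v / (1 - b2 ** (step + 1))
  params = params - learning_rate * mhat / (onp.sqrt(vhat) + eps)
  return params, (m, v)

params = onp.ones(3, dtype=onp.float32)
m, v = onp.zeros_like(params), onp.zeros_like(params)  # as in init() above
params, (m, v) = adam_step(0, onp.full(3, 0.1, dtype=onp.float32), params, m, v)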
github google / trax / trax / layers / research / efficient_attention.py
      def forward_slice_with_q_loop_idx(query_slice, key, value):  # pylint: disable=invalid-name
        return forward_slice(query_slice, q_loop_idx, key, value)

      output_slice, vjpfun = jax.vjp(
          forward_slice_with_q_loop_idx, query_slice, key, value)
      return output_slice, vjpfun(ct_slice)

    q_loop_idx = np.zeros((), dtype=np.int32)
    q_loop_max = query.shape[-2]
    q_loop_stride = self._loop_stride
    if q_loop_max == 1:  # For abstract runs with unknown shapes.
      q_loop_stride = 1
    assert q_loop_max % q_loop_stride == 0, (
        'Stride must evenly divide the number of query elements.')

    out_accum = np.zeros_like(query)
    if do_backprop:
      query_ct_accum = np.zeros_like(query)
      key_ct_accum = np.zeros_like(key)
      value_ct_accum = np.zeros_like(value)
      init_vals = (
          q_loop_idx, out_accum,
          query_ct_accum, key_ct_accum, value_ct_accum)
    else:
      init_vals = (q_loop_idx, out_accum)

    def cond_fun(vals):  # pylint: disable=invalid-name
      q_loop_idx = vals[0]
      return jax.lax.lt(q_loop_idx, q_loop_max)

    def body_fun(vals):  # pylint: disable=invalid-name
      """Compute a slice of the attention mechanism."""
      if do_backprop:
        (q_loop_idx, out_accum,
         query_ct_accum, key_ct_accum, value_ct_accum) = vals
      else:
        (q_loop_idx, out_accum) = vals
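The zeros_like calls above create accumulators that get threaded through a cond_fun/body_fun loop (jax.lax.while_loop style). Here is a minimal, self-contained sketch of that pattern written directly against jax.numpy rather than trax.backend.numpy; process_in_slices and its per-slice computation are illustrative only:

import jax
import jax.numpy as jnp

def process_in_slices(x, stride=1):
  """Processes x one slice at a time, accumulating results as in out_accum above."""
  out_accum = jnp.zeros_like(x)              # accumulator with x's shape and dtype
  loop_idx = jnp.zeros((), dtype=jnp.int32)  # scalar loop counter

  def cond_fun(vals):
    idx, _ = vals
    return jax.lax.lt(idx, x.shape[0])

  def body_fun(vals):
    idx, acc = vals
    x_slice = jax.lax.dynamic_slice_in_dim(x, idx, stride, axis=0)
    out_slice = 2.0 * x_slice                # stand-in for the real per-slice work
    acc = jax.lax.dynamic_update_slice_in_dim(acc, out_slice, idx, axis=0)
    return (idx + stride, acc)

  _, out = jax.lax.while_loop(cond_fun, body_fun, (loop_idx, out_accum))
  return out

print(process_in_slices(jnp.arange(6.0).reshape(3, 2)))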
github google / trax / trax / layers / research / efficient_attention.py
    # The approach here is to perform attention for one batch element and head
    # at a time. Note that there is absolutely no interaction across examples or
    # heads: this layer has no parameters, and hashing patterns are also
    # different across examples/heads. As a result, batching doesn't give any
    # performance gains except in the case of accelerator under-utilization. We
    # assume that hash-based attention will be applied primarily to long
    # sequences, where unbatched attention for a single head has sufficient
    # computation to fill up the accelerator.

    batch_loop_idx = np.zeros((), dtype=np.int32)
    batch_loop_max = qk.shape[0]

    init_vals = (batch_loop_idx,)
    if return_output:
      out_accum = np.zeros_like(qk)
      init_vals = init_vals + (out_accum,)
    if return_state:
      buckets_accum = np.zeros(
          [qk.shape[0], self.n_hashes * qk.shape[1]], dtype=np.int32)
      init_vals = init_vals + (buckets_accum,)
    if ct is not None:
      qk_ct_accum = np.zeros_like(qk)
      v_ct_accum = np.zeros_like(v)
      init_vals = init_vals + (qk_ct_accum, v_ct_accum)

    def cond_fun(vals):
      batch_loop_idx = vals[0]
      return jax.lax.lt(batch_loop_idx, batch_loop_max)

    def body_fun(vals):
      """Performs attention for a single batch element and head."""
github google / trax / trax / layers / research / efficient_attention.py
def forward_and_backward(self, inputs, ct, state=base.EMPTY_STATE,
                           new_state=base.EMPTY_STATE, rng=None, **kwargs):
    del kwargs
    output, _, (qk_ct, v_ct) = self.batch_call_and_or_grad(
        inputs[0], inputs[2], ct=ct, new_state=new_state, rng=rng)
    return output, (qk_ct, np.zeros_like(inputs[1]), v_ct)
github google / trax / trax / layers / research / efficient_attention.py
def drop_for_hash(self, x, rng):
    rate = self._drop_for_hash_rate
    if self._mode == 'train' and rate > 0.0:
      keep = backend.random.bernoulli(rng, 1.0 - rate, x.shape)
      return np.where(keep, x / (1.0 - rate), np.zeros_like(x))
    return x
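This is the usual inverted-dropout idiom: np.zeros_like supplies a zero branch with the same shape and dtype as x for np.where, and the surviving entries are rescaled by 1 / (1 - rate). A self-contained sketch of the same pattern in plain jax.numpy (the dropout helper below is illustrative, not a trax API):

import jax
import jax.numpy as jnp

def dropout(x, rate, rng, train=True):
  # Inverted dropout: zero out entries and rescale the survivors so the
  # expected value of each entry is unchanged.
  if not train or rate <= 0.0:
    return x
  keep = jax.random.bernoulli(rng, 1.0 - rate, x.shape)
  return jnp.where(keep, x / (1.0 - rate), jnp.zeros_like(x))

x = jnp.ones((2, 4))
print(dropout(x, rate=0.5, rng=jax.random.PRNGKey(0)))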
github google / trax / trax / optimizers / sm3.py
def _update_sketched(self, grads, weights, m, v, opt_params):
    """Update for higher-rank parameters."""
    learning_rate = opt_params['learning_rate']
    momentum = opt_params['momentum']
    shape = weights.shape
    rank = len(shape)
    reshaped_accumulators = [np.reshape(v[i], self._expanded_shape(shape, i))
                             for i in range(rank)]
    current_accumulator = self._minimum(reshaped_accumulators)
    current_accumulator += grads * grads
    accumulator_inv_sqrt = np.where(current_accumulator > 0.0,
                                    1.0 / np.sqrt(current_accumulator),
                                    np.zeros_like(current_accumulator))
    preconditioned_gradient = grads * accumulator_inv_sqrt
    m = (1.0 - momentum) * preconditioned_gradient + momentum * m
    weights = weights - (learning_rate * m).astype(weights.dtype)
    for i in range(len(v)):
      axes = list(range(int(i))) + list(range(int(i) + 1, rank))
      dim_accumulator = np.amax(current_accumulator, axis=axes)
      v[i] = dim_accumulator
    return weights, (m, v)
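The np.where(current_accumulator > 0.0, 1.0 / np.sqrt(...), np.zeros_like(...)) line guards against dividing by zero for accumulator entries that are still empty; zeros_like keeps the fallback branch shape- and dtype-compatible with the true branch. A tiny standalone sketch of the same guard (safe_inv_sqrt is a hypothetical name):

import jax.numpy as jnp

def safe_inv_sqrt(acc):
  # 1/sqrt(acc) where acc > 0, and 0 where the accumulator is still zero.
  return jnp.where(acc > 0.0, 1.0 / jnp.sqrt(acc), jnp.zeros_like(acc))

print(safe_inv_sqrt(jnp.array([0.0, 4.0, 16.0])))  # [0.  0.5  0.25]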
github google / trax / trax / layers / attention.py
  if mask is not None:
    # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
    # We must ensure that both mask and the -1e9 constant have a data dependency
    # on the input. Broadcasted copies of these use a lot of memory, so they
    # should be computed at runtime (rather than being global constants).
    if backend.get_name() == 'jax':
      mask = jax.lax.tie_in(dots, mask)
    # JAX's `full_like` already ties in -1e9 to dots.
    dots = np.where(mask, dots, np.full_like(dots, -1e9))
  # Softmax.
  dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))
  if dropout is not None and dropout >= 1.0:
    raise ValueError('Dropout rates must be lower than 1.')
  if dropout is not None and dropout > 0.0 and mode == 'train':
    keep = backend.random.bernoulli(rng, 1.0 - dropout, dots.shape)
    dots = np.where(keep, dots / (1.0 - dropout), np.zeros_like(dots))
  out = np.matmul(dots, value)
  return out