How to use the trax.backend.numpy.zeros function in trax

To help you get started, we’ve selected a few trax examples that show popular ways trax.backend.numpy.zeros is used in public projects.

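Every example below reaches np.zeros through trax's backend-dispatched numpy rather than classic numpy, so the returned arrays are backend (typically JAX) arrays. As a minimal sketch, assuming a trax version where the backend numpy is importable as trax.backend.numpy:

from trax.backend import numpy as np

# Zero-filled arrays created through trax's backend; shapes and dtypes follow
# the familiar numpy API, but the result is a backend (typically JAX) array.
counter = np.zeros((), dtype=np.int32)              # scalar, shape ()
buffer = np.zeros((8, 64, 512), dtype=np.float32)   # (batch, length, depth)
mirror = np.zeros_like(buffer)                      # same shape and dtype as buffer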

github google / trax / trax / optimizers / adafactor.py View on Github external
def init(self, params):
    shape = params.shape
    slots = []
    if self._factored and len(shape) >= 2:
      v_row = np.zeros(shape[:-1], dtype=np.float32)
      v_col = np.zeros(shape[:-2] + shape[-1:], dtype=np.float32)
      slots.extend([v_row, v_col])
    else:
      v = np.zeros_like(params)
      slots.append(v)
    if self._do_momentum:
      m = np.zeros_like(params)
      slots.append(m)
    return slots
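
In the factored branch above, Adafactor keeps one zero vector along the row dimension and one along the column dimension instead of a full second-moment matrix; the non-factored branch and the optional momentum slot are full-size zero buffers. A rough sketch of the resulting slot shapes, with an illustrative (1024, 512) weight standing in for params:

from trax.backend import numpy as np

params = np.zeros((1024, 512), dtype=np.float32)  # illustrative weight matrix
shape = params.shape

# Factored second-moment estimators: one vector per row axis, one per column axis.
v_row = np.zeros(shape[:-1], dtype=np.float32)                # shape (1024,)
v_col = np.zeros(shape[:-2] + shape[-1:], dtype=np.float32)   # shape (512,)

# Non-factored fallback and optional momentum use full-size buffers.
v = np.zeros_like(params)   # shape (1024, 512)
m = np.zeros_like(params)   # shape (1024, 512)
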
github google / trax / trax / layers / attention.py View on Github external
def _fast_inference_init_state(input_signature, buffer_length):
  """Returns an initial state for causal attention layer fast inference."""
  def zeros_for(batch_size, shape_dtype):
    shape, dtype = shape_dtype.as_tuple()
    depth = shape[-1]
    return np.zeros((batch_size, buffer_length, depth), dtype=dtype)

  batch_size = input_signature[0].shape[0]
  k = zeros_for(batch_size, input_signature[1])
  v = zeros_for(batch_size, input_signature[2])
  mask = np.zeros((batch_size, 1, buffer_length))
  index = 0
  return (k, v, mask, index)
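
The zero buffers above act as empty key/value caches for incremental decoding: nothing has been written yet, the mask marks every cache slot as unattended, and index points at the first write position. A small sketch of the shapes involved, using illustrative sizes in place of the real input signature:

from trax.backend import numpy as np

batch_size, buffer_length, depth = 2, 16, 8   # illustrative sizes

k = np.zeros((batch_size, buffer_length, depth), dtype=np.float32)   # key cache
v = np.zeros((batch_size, buffer_length, depth), dtype=np.float32)   # value cache
mask = np.zeros((batch_size, 1, buffer_length))                      # nothing cached yet
index = 0                                                            # next write position
state = (k, v, mask, index)
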
github google / trax / trax / layers / normalization.py View on Github external
def new_weights(self, input_signature):
    # Usually (B, W, H, C)
    shape = input_signature.shape
    num_channels = shape[-1]

    gamma = np.ones((num_channels,), dtype=np.float32)
    beta = np.zeros((num_channels,), dtype=np.float32)

    epsilon_l = base.EMPTY_WEIGHTS
    if self._learn_epsilon:
      epsilon_l = (self._init_learnt_epsilon,)

    return gamma, beta, epsilon_l
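
The per-channel scale starts at one and the shift at zero, so the layer begins as an identity transform over the channel dimension. A small sketch of the weight shapes for an illustrative (B, W, H, C) input signature:

from trax.backend import numpy as np

input_shape = (8, 32, 32, 64)    # (B, W, H, C), illustrative
num_channels = input_shape[-1]

gamma = np.ones((num_channels,), dtype=np.float32)   # per-channel scale, shape (64,)
beta = np.zeros((num_channels,), dtype=np.float32)   # per-channel shift, shape (64,)
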
github google / trax / trax / models / research / position_lookup_transformer.py View on Github external
def NewPositionalEncoding(x, positions=None, **kwargs):
  """Implements new positional encoding."""
  del kwargs
  x_length = np.shape(x)[1]
  pos = np.array(positions)[np.newaxis, :x_length, :]
  pos += np.zeros((np.shape(x)[0], 1, 1))  # Broadcast on batch.
  return pos
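
Here np.zeros is a broadcasting trick: adding a (batch, 1, 1) block of zeros to the (1, length, depth) positions slice tiles the positions across the batch without an explicit repeat. A standalone sketch with illustrative stand-ins for x and positions:

from trax.backend import numpy as np

batch, length, depth = 4, 10, 16              # illustrative sizes
positions = np.ones((128, depth))             # stand-in position table
x = np.zeros((batch, length, depth))          # stand-in input activations

x_length = np.shape(x)[1]
pos = np.array(positions)[np.newaxis, :x_length, :]    # (1, length, depth)
pos = pos + np.zeros((np.shape(x)[0], 1, 1))           # broadcasts to (batch, length, depth)
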
github google / trax / trax / layers / research / efficient_attention.py View on Github external
    assert return_output or ct is not None, 'No work to perform!'
    if new_state is not None and new_state is not base.EMPTY_STATE:
      buckets = new_state
    else:
      buckets = None

    # The approach here is to perform attention for one batch element and head
    # at a time. Note that there is absolutely no interaction across examples or
    # heads: this layer has no parameters, and hashing patterns are also
    # different across examples/heads. As a result, batching doesn't give any
    # performance gains except in the case of accelerator under-utilization. We
    # assume that hash-based attention will be applied primarily to long
    # sequences, where unbatched attention for a single head has sufficient
    # computation to fill up the accelerator.

    batch_loop_idx = np.zeros((), dtype=np.int32)
    batch_loop_max = qk.shape[0]

    init_vals = (batch_loop_idx,)
    if return_output:
      out_accum = np.zeros_like(qk)
      init_vals = init_vals + (out_accum,)
    if return_state:
      buckets_accum = np.zeros(
          [qk.shape[0], self.n_hashes * qk.shape[1]], dtype=np.int32)
      init_vals = init_vals + (buckets_accum,)
    if ct is not None:
      qk_ct_accum = np.zeros_like(qk)
      v_ct_accum = np.zeros_like(v)
      init_vals = init_vals + (qk_ct_accum, v_ct_accum)

    def cond_fun(vals):
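
Each accumulator starts as zeros shaped like the tensor it will eventually hold, and the loop counter is a zero-dimensional int32 so it can sit in the loop carry. A rough sketch of those shapes, with illustrative sizes standing in for qk, v and self.n_hashes:

from trax.backend import numpy as np

n_slices, seqlen, d_qk, d_v, n_hashes = 2, 1024, 64, 64, 4   # illustrative sizes
qk = np.zeros((n_slices, seqlen, d_qk))   # stand-in for the real queries/keys
v = np.zeros((n_slices, seqlen, d_v))     # stand-in for the real values

batch_loop_idx = np.zeros((), dtype=np.int32)   # scalar loop counter, shape ()
out_accum = np.zeros_like(qk)                   # accumulates per-slice outputs
buckets_accum = np.zeros([qk.shape[0], n_hashes * qk.shape[1]], dtype=np.int32)
qk_ct_accum = np.zeros_like(qk)                 # cotangent accumulator for qk
v_ct_accum = np.zeros_like(v)                   # cotangent accumulator for v
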
github google / trax / trax / layers / normalization.py View on Github external
def new_weights_and_state(self, input_signature):
    """Helper to initialize batch norm weights."""
    axis = self._axis
    axis = (axis,) if np.isscalar(axis) else axis
    input_shape = input_signature.shape
    shape = tuple(d for i, d in enumerate(input_shape) if i not in axis)
    beta = np.zeros(shape, dtype='float32') if self._center else ()
    gamma = np.ones(shape, dtype='float32') if self._scale else ()
    def get_stats_axis(i, d):
      if i in axis:
        return 1
      else:
        return d
    stats_shape = tuple(get_stats_axis(i, d) for i, d in enumerate(input_shape))
    running_mean = np.zeros(stats_shape, dtype=np.float32)
    running_var = np.ones(stats_shape, dtype=np.float32)
    n_batches = np.zeros((), dtype=np.int64)
    weights = (beta, gamma)
    state = (running_mean, running_var, n_batches)
    return weights, state
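
With an NHWC input and axis=(0, 1, 2), the learned beta/gamma collapse to per-channel vectors while the running statistics keep singleton axes so they broadcast against the input; the batch counter is a zero-dimensional array. A small sketch of the resulting shapes with illustrative sizes:

from trax.backend import numpy as np

input_shape = (8, 32, 32, 64)   # (N, H, W, C), illustrative
axis = (0, 1, 2)                # reduce over batch and spatial axes

shape = tuple(d for i, d in enumerate(input_shape) if i not in axis)
stats_shape = tuple(1 if i in axis else d for i, d in enumerate(input_shape))

beta = np.zeros(shape, dtype='float32')                  # (64,)
gamma = np.ones(shape, dtype='float32')                  # (64,)
running_mean = np.zeros(stats_shape, dtype=np.float32)   # (1, 1, 1, 64)
running_var = np.ones(stats_shape, dtype=np.float32)     # (1, 1, 1, 64)
n_batches = np.zeros((), dtype=np.int64)                 # scalar batch counter
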
github google / trax / trax / layers / research / efficient_attention.py View on Github external
    # heads: this layer has no parameters, and hashing patterns are also
    # different across examples/heads. As a result, batching doesn't give any
    # performance gains except in the case of accelerator under-utilization. We
    # assume that hash-based attention will be applied primarily to long
    # sequences, where unbatched attention for a single head has sufficient
    # computation to fill up the accelerator.

    batch_loop_idx = np.zeros((), dtype=np.int32)
    batch_loop_max = qk.shape[0]

    init_vals = (batch_loop_idx,)
    if return_output:
      out_accum = np.zeros_like(qk)
      init_vals = init_vals + (out_accum,)
    if return_state:
      buckets_accum = np.zeros(
          [qk.shape[0], self.n_hashes * qk.shape[1]], dtype=np.int32)
      init_vals = init_vals + (buckets_accum,)
    if ct is not None:
      qk_ct_accum = np.zeros_like(qk)
      v_ct_accum = np.zeros_like(v)
      init_vals = init_vals + (qk_ct_accum, v_ct_accum)

    def cond_fun(vals):
      batch_loop_idx = vals[0]
      return jax.lax.lt(batch_loop_idx, batch_loop_max)

    def body_fun(vals):
      """Performs attention for a single batch element and head."""
      batch_loop_idx = vals[0]
      if self._prng is None:
        hash_slice_rng = jax.random.fold_in(rng, batch_loop_idx)
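
cond_fun and body_fun are the two halves of a jax.lax.while_loop that walks the zero-initialized carry built earlier, folding the loop index into the RNG so every example/head slice gets its own randomness. A stripped-down sketch of that pattern (the body is a stand-in, not the actual attention update):

import jax
from trax.backend import numpy as np

rng = jax.random.PRNGKey(0)
batch_loop_max = 4                 # illustrative number of slices to visit

def cond_fun(vals):
  batch_loop_idx = vals[0]
  return jax.lax.lt(batch_loop_idx, batch_loop_max)

def body_fun(vals):
  batch_loop_idx, acc = vals
  # Per-iteration RNG derived from the loop index, as in the snippet above.
  slice_rng = jax.random.fold_in(rng, batch_loop_idx)
  # Stand-in for the real per-slice attention computation.
  acc = acc + jax.random.uniform(slice_rng, acc.shape)
  return (batch_loop_idx + 1, acc)

init_vals = (np.zeros((), dtype=np.int32), np.zeros((8, 16)))
final_idx, final_acc = jax.lax.while_loop(cond_fun, body_fun, init_vals)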