How to use the trax.backend.numpy.concatenate function in trax

To help you get started, we’ve selected a few examples of trax.backend.numpy.concatenate, based on popular ways it is used in public projects.

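trax.backend.numpy mirrors the NumPy API (it dispatches to jax.numpy or plain numpy depending on the active backend), so concatenate takes a list of arrays and an axis argument, just like np.concatenate. Below is a minimal sketch, not taken from the examples that follow, assuming a trax version that still exposes trax.backend (newer releases moved this module, eventually to trax.fastmath):

from trax.backend import numpy as np  # jax.numpy-compatible interface

a = np.zeros((2, 3))
b = np.ones((2, 3))

rows = np.concatenate([a, b], axis=0)   # stack along rows -> shape (4, 3)
cols = np.concatenate([a, b], axis=-1)  # stack along the last axis -> shape (2, 6)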

google/trax: trax/layers/research/efficient_attention.py (view on GitHub)
random_rotations = self._sample_rotation(rotations_shape, vecs, rng)

# TODO(lukaszkaiser): the dropout mask will be used for all rounds of
# hashing, so it's shared between them. Check if that's what we want.
dropped_vecs = self.drop_for_hash(vecs, subrng)
rotated_vecs = np.einsum('tf,fhb->htb', dropped_vecs, random_rotations)

if self._rehash_each_round:
  if self._factorize_hash and len(factor_list) > 1:
    # We factorized self.n_buckets as the product of factor_list.
    # Get the buckets for them and combine.
    buckets, cur_sum, cur_product = None, 0, 1
    for factor in factor_list:
      rv = rotated_vecs[..., cur_sum:cur_sum + (factor // 2)]
      cur_sum += factor // 2
      rv = np.concatenate([rv, -rv], axis=-1)
      if buckets is None:
        buckets = np.argmax(rv, axis=-1)
      else:
        buckets += cur_product * np.argmax(rv, axis=-1)
      cur_product *= factor
  else:
    rotated_vecs = np.concatenate([rotated_vecs, -rotated_vecs], axis=-1)
    buckets = np.argmax(rotated_vecs, axis=-1)
  # buckets is now (self.n_hashes, seqlen). Next we add offsets so that
  # bucket numbers from different hashing rounds don't overlap.
  offsets = jax.lax.tie_in(buckets, np.arange(self.n_hashes))
  offsets = np.reshape(offsets * self.n_buckets, (-1, 1))
  buckets = np.reshape(buckets + offsets, (-1,))
else:
  assert not self._factorize_hash
  rotated_vecs = np.concatenate([rotated_vecs, -rotated_vecs], axis=-1)
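The np.concatenate([rv, -rv], axis=-1) pattern above is the locality-sensitive-hashing trick used by Reformer-style attention: each random rotation contributes two opposite bucket directions, so concatenating a projection with its negation and taking argmax assigns every position to the nearest of n_buckets directions. A standalone sketch of the idea in plain NumPy (toy shapes, not the actual trax code):

import numpy as np

rng = np.random.default_rng(0)
seqlen, d_model, n_buckets = 8, 16, 6

vecs = rng.normal(size=(seqlen, d_model))
# n_buckets // 2 random hyperplanes; the negated copy supplies the other half.
rotations = rng.normal(size=(d_model, n_buckets // 2))

rotated = vecs @ rotations                              # (seqlen, n_buckets // 2)
rotated = np.concatenate([rotated, -rotated], axis=-1)  # (seqlen, n_buckets)
buckets = np.argmax(rotated, axis=-1)                   # one bucket id per position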
google/trax: trax/models/research/position_lookup_transformer.py (view on GitHub)
def forward(self, inp, weights):
    """Reshape input to have heads dimension and concatenate positions there."""
    x = inp[0]
    n_batches, seqlen = x.shape[0], x.shape[1]
    d_head = x.shape[-1] // self._n_heads
    res = np.reshape(x, (n_batches, seqlen, self._n_heads, d_head))
    res = np.transpose(res, (0, 2, 1, 3))  # (batch, heads, len, depth)
    if self._n_pos == 1:  # Just one position given, tile into each head.
      pos_shape = list(res.shape)[:-1] + [inp[1].shape[-1]]
      pos = inp[1][:, None, :, :] + np.zeros(pos_shape)  # Add 0 to broadcast.
    else:  # As many positions as heads, concatenate them in.
      pos = [p[:, None, :, :] for p in inp[1:]]
      pos = np.concatenate(pos, axis=1)
    res = np.concatenate([res, pos], axis=-1)
    # n_batch, n_heads, seqlen, d_head -> n_batch*n_heads, seqlen, d_head
    res = np.reshape(res, (-1, seqlen, d_head + POS_VECTOR_SIZE))
    return res
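In this forward pass concatenate appears twice: once to stack the per-head position vectors along the heads axis (axis=1), and once to append the position vector to every head's features along the last axis. A shape-only sketch in plain NumPy with made-up sizes:

import numpy as np

n_batches, seqlen, n_heads, d_head, pos_size = 2, 5, 4, 8, 3

x = np.zeros((n_batches, seqlen, n_heads * d_head))
pos = np.zeros((n_batches, seqlen, pos_size))

res = x.reshape(n_batches, seqlen, n_heads, d_head).transpose(0, 2, 1, 3)
# Broadcast the single position vector into every head, then append it
# to each head's features along the last axis.
pos = np.broadcast_to(pos[:, None, :, :], (n_batches, n_heads, seqlen, pos_size))
res = np.concatenate([res, pos], axis=-1)
res = res.reshape(-1, seqlen, d_head + pos_size)
print(res.shape)  # (8, 5, 11), i.e. (n_batches * n_heads, seqlen, d_head + pos_size)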
google/trax: trax/layers/research/efficient_attention.py (view on GitHub)
if self._rehash_each_round:
  if self._factorize_hash and len(factor_list) > 1:
    # We factorized self.n_buckets as the product of factor_list.
    # Get the buckets for them and combine.
    buckets, cur_sum, cur_product = None, 0, 1
    for factor in factor_list:
      rv = rotated_vecs[..., cur_sum:cur_sum + (factor // 2)]
      cur_sum += factor // 2
      rv = np.concatenate([rv, -rv], axis=-1)
      if buckets is None:
        buckets = np.argmax(rv, axis=-1)
      else:
        buckets += cur_product * np.argmax(rv, axis=-1)
      cur_product *= factor
  else:
    rotated_vecs = np.concatenate([rotated_vecs, -rotated_vecs], axis=-1)
    buckets = np.argmax(rotated_vecs, axis=-1)
  # buckets is now (self.n_hashes, seqlen). Next we add offsets so that
  # bucket numbers from different hashing rounds don't overlap.
  offsets = jax.lax.tie_in(buckets, np.arange(self.n_hashes))
  offsets = np.reshape(offsets * self.n_buckets, (-1, 1))
  buckets = np.reshape(buckets + offsets, (-1,))
else:
  assert not self._factorize_hash
  rotated_vecs = np.concatenate([rotated_vecs, -rotated_vecs], axis=-1)
  # In this configuration, we map each item to the top self.n_hashes buckets
  rotated_vecs = np.squeeze(rotated_vecs, 0)
  bucket_range = jax.lax.tie_in(vecs, np.arange(rotated_vecs.shape[-1]))
  bucket_range = np.reshape(bucket_range, (1, -1))
  bucket_range = np.broadcast_to(bucket_range, rotated_vecs.shape)

  _, buckets = jax.lax.sort_key_val(
google/trax: trax/models/research/reformer.py (view on GitHub)
def reverse(self, output, weights=(), state=(), new_state=(), **kwargs):
    del weights, kwargs

    x1_split = []
    x2_split = []
    for y in output:
      y1, y2 = np.split(y, 2, -1)
      x1_split.append(y1)
      x2_split.append(y2)

    x1 = np.concatenate(x1_split, self._axis)
    x2 = np.concatenate(x2_split, self._axis)

    return (x1, x2)
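reverse rebuilds the two streams of a reversible layer by splitting every output chunk in half and concatenating the halves back along self._axis. The identity it relies on is that np.split and np.concatenate invert each other; a tiny sanity check in plain NumPy:

import numpy as np

y = np.arange(24, dtype=np.float32).reshape(2, 3, 4)

y1, y2 = np.split(y, 2, axis=-1)              # two (2, 3, 2) halves
restored = np.concatenate([y1, y2], axis=-1)  # back to (2, 3, 4)

assert np.array_equal(restored, y)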
google/trax: trax/layers/research/efficient_attention.py (view on GitHub)
def look_one_back(x):
  # Output: pairs [ bin_i bin_{i-1} ] concatenated on the time axis.
  if len(x.shape) == 2:
    x_extra = np.concatenate([x[-1:, :], x[:-1, :]], axis=0)
    return np.concatenate([x, x_extra], axis=1)
  else:
    assert len(x.shape) == 4
    x_extra = np.concatenate([x[:, -1:, :, :], x[:, :-1, :, :]], axis=1)
    return np.concatenate([x, x_extra], axis=2)
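look_one_back lets every chunk attend to the previous chunk by concatenating a rolled copy of the input on the time axis (the first chunk wraps around to the last one). A toy run of the 2-d branch in plain NumPy, with a hypothetical 3-chunk input:

import numpy as np

def look_one_back(x):
  # Pair chunk i with chunk i - 1 by rolling the chunk axis one step.
  x_extra = np.concatenate([x[-1:, :], x[:-1, :]], axis=0)
  return np.concatenate([x, x_extra], axis=1)

chunks = np.array([[0, 1],
                   [2, 3],
                   [4, 5]])
print(look_one_back(chunks))
# [[0 1 4 5]
#  [2 3 0 1]
#  [4 5 2 3]]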
google/trax: trax/layers/research/efficient_attention.py (view on GitHub)
def look_one_back(x):
  if len(x.shape) == 2:
    x_extra = np.concatenate([x[-1:, :], x[:-1, :]], axis=0)
  else:
    x_extra = np.concatenate([x[-1:, :, :], x[:-1, :, :]], axis=0)
  return np.concatenate([x, x_extra], axis=1)
google/trax: trax/models/research/position_lookup_transformer.py (view on GitHub)
def QueryPositionKV(x, keys=None, values=None, binary=False, **unused_kwargs):
  """Query a table with a position vector."""
  if keys is None:
    return x
  k = np.array(keys)
  v = np.array(values)
  q = x
  if binary:
    q = np.concatenate([x, x], axis=-1)
  return tl.DotProductAttention(q, k, v, None, 0.0, None, None)
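When binary=True the query is doubled along the feature axis before the attention lookup, presumably so it matches the width of keys built from two concatenated position tables. The shape effect is simply:

import numpy as np

x = np.ones((4, 8))                  # hypothetical batch of position queries
q = np.concatenate([x, x], axis=-1)  # binary=True doubles the last axis
print(q.shape)                       # (4, 16)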
google/trax: trax/models/research/position_lookup_transformer.py (view on GitHub)
the vector with combined xs and one with combined positions.
  """
  seqlen = x.shape[1]
  d_head = x.shape[2]
  x = np.reshape(x, (-1, n_heads, seqlen, d_head))
  x = np.transpose(x, (0, 2, 1, 3))  # -> n_batch, seqlen, n_heads, d_head
  x = np.reshape(x, (-1, seqlen, n_heads * d_head))
  head_size = int(d_head) - POS_VECTOR_SIZE
  res, positions, idx = [], [], 0
  for _ in range(n_heads):
    res.append(x[:, :, idx:idx+head_size])
    idx += head_size
    positions.append(x[:, :, idx:idx+POS_VECTOR_SIZE])
    idx += POS_VECTOR_SIZE
  combined_position = sum(positions) / float(len(positions))
  return np.concatenate(res, axis=-1), combined_position
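This helper walks the last axis in alternating (features, position) slices, one pair per head, then concatenates the feature slices back into a single vector and averages the position slices. A shape check in plain NumPy with made-up sizes (pos_size stands in for POS_VECTOR_SIZE):

import numpy as np

n_heads, head_size, pos_size, seqlen = 2, 3, 2, 4
d_head = head_size + pos_size

# One (features, position) pair per head, interleaved along the last axis.
x = np.zeros((1, seqlen, n_heads * d_head))

res, positions, idx = [], [], 0
for _ in range(n_heads):
  res.append(x[:, :, idx:idx + head_size])
  idx += head_size
  positions.append(x[:, :, idx:idx + pos_size])
  idx += pos_size

combined_position = sum(positions) / float(len(positions))
features = np.concatenate(res, axis=-1)
print(features.shape, combined_position.shape)  # (1, 4, 6) (1, 4, 2)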