How to use trax - 10 common examples

To help you get started, we’ve selected a few trax examples, based on popular ways it is used in public projects.

Secure your code as it's written: use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

GitHub: google/trax — trax/layers/research/efficient_attention.py (View on GitHub)
while factor > 0 and not (
            self.n_buckets % factor == 0 and
            factor % 2 == 0 and
            (self.n_buckets // factor) % 2 == 0):
          factor -= 1
        if factor > 2:  # Factor of 2 does not warrant the effort.
          rot_size = factor + (self.n_buckets // factor)
          factor_list = [factor, self.n_buckets // factor]

    rotations_shape = (
        vecs.shape[-1],
        self.n_hashes if self._rehash_each_round else 1,
        rot_size // 2)

    rng = jax.lax.tie_in(vecs, rng)
    rng, subrng = backend.random.split(rng)
    random_rotations = self._sample_rotation(rotations_shape, vecs, rng)

    # TODO(lukaszkaiser): the dropout mask will be used for all rounds of
    # hashing, so it's shared between them. Check if that's what we want.
    dropped_vecs = self.drop_for_hash(vecs, subrng)
    rotated_vecs = np.einsum('tf,fhb->htb', dropped_vecs, random_rotations)

    if self._rehash_each_round:
      if self._factorize_hash and len(factor_list) > 1:
        # We factorized self.n_buckets as the product of factor_list.
        # Get the buckets for them and combine.
        buckets, cur_sum, cur_product = None, 0, 1
        for factor in factor_list:
          rv = rotated_vecs[..., cur_sum:cur_sum + (factor // 2)]
          cur_sum += factor // 2
          rv = np.concatenate([rv, -rv], axis=-1)
GitHub: google/trax — trax/models/research/reformer.py (View on GitHub)
def Chunk(x, weights, n_sections=2, **kwargs):
  """Fold the length axis of `x` into its batch axis.

  Axis 1 is cut into `n_sections` equal pieces. An input of shape
  (batch, length, ...) becomes (batch * n_sections, length // n_sections, ...).
  Under row-major reshape semantics, the pieces of each example end up in
  consecutive rows of the new leading axis.

  Args:
    x: array with at least two axes; x.shape[1] must be divisible by
      n_sections (checked with an assert).
    weights: unused.
    n_sections: number of pieces to cut axis 1 into.
    **kwargs: unused.

  Returns:
    The reshaped array.
  """
  del weights, kwargs
  batch, length = x.shape[0], x.shape[1]
  assert length % n_sections == 0
  leading = (batch * n_sections, length // n_sections)
  return np.reshape(x, leading + x.shape[2:])
GitHub: google/trax — trax/supervised/trainer_lib.py (View on GitHub)
def f(x):
    """Split the leading (batch) axis of `x` into (n_devices, per-device batch).

    An array of shape (batch, ...) is reshaped to
    (n_devices, batch // n_devices, ...). `n_devices` and `backend` are read
    from the enclosing scope.

    Raises:
      ValueError: if n_devices does not divide the batch size evenly.
    """
    dims = list(x.shape)
    batch_size = dims[0]
    per_device, leftover = divmod(batch_size, n_devices)
    if leftover:
      raise ValueError(
          'We require that n_devices[%d] divides batch_size[%d] evenly.' %
          (n_devices, batch_size))
    return backend.numpy.reshape(x, [n_devices, per_device] + dims[1:])
  return backend.nested_map(f, x)
GitHub: google/trax — trax/models/research/reformer.py (View on GitHub)
def Unchunk(x, weights, n_sections=2, **kwargs):
  """Fold groups of `n_sections` batch rows back into the length axis.

  An input of shape (batch * n_sections, chunk_len, ...) becomes
  (batch, chunk_len * n_sections, ...). Under row-major reshape semantics,
  consecutive rows are joined end to end along axis 1.

  Args:
    x: array with at least two axes; x.shape[0] must be divisible by
      n_sections (checked with an assert).
    weights: unused.
    n_sections: number of consecutive rows merged into one.
    **kwargs: unused.

  Returns:
    The reshaped array.
  """
  del weights, kwargs
  rows = x.shape[0]
  assert rows % n_sections == 0
  merged = (rows // n_sections, x.shape[1] * n_sections)
  return np.reshape(x, merged + x.shape[2:])
GitHub: google/trax — trax/models/research/reformer.py (View on GitHub)
x.shape[0] * n_sections,
      x.shape[1] // n_sections,
      ) + x.shape[2:])


@tl.layer()
def Unchunk(x, weights, n_sections=2, **kwargs):
  """Fold groups of `n_sections` batch rows back into the length axis.

  An input of shape (batch * n_sections, chunk_len, ...) becomes
  (batch, chunk_len * n_sections, ...). Under row-major reshape semantics,
  consecutive rows are joined end to end along axis 1.

  Args:
    x: array with at least two axes; x.shape[0] must be divisible by
      n_sections (checked with an assert).
    weights: unused.
    n_sections: number of consecutive rows merged into one.
    **kwargs: unused.

  Returns:
    The reshaped array.
  """
  del weights, kwargs
  rows = x.shape[0]
  assert rows % n_sections == 0
  merged = (rows // n_sections, x.shape[1] * n_sections)
  return np.reshape(x, merged + x.shape[2:])


class ReversibleHalfResidual(tl.ReversibleLayer, tl.Serial):
  """Half of a RevNet-style residual (only updates part of the hidden state).

  Built as a serial pair of stages:
    1. `compute_residual`: (x1_or_y1, x2) -> (F(x2), x1_or_y1, x2), where F is
       `residual_layers`.
    2. Parallel(Add(), []): presumably combines F(x2) with x1_or_y1 while x2
       passes through unchanged (relies on tl.Add semantics; confirm).
  """

  def __init__(self, residual_layers):
    # Stack flow of the residual branch:
    #   (x1_or_y1, x2) -> (x1_or_y1, x2, x2)    Dup on the second item
    #                  -> (x2, x1_or_y1, x2)    Swap
    #                  -> (F(x2), x1_or_y1, x2) residual_layers on the top
    branch_steps = [
        tl.Parallel([], tl.Dup()),
        tl.Swap(),
        tl.Parallel(residual_layers, [], []),
    ]
    self.compute_residual = tl.Serial(*branch_steps)

    merge = tl.Parallel(tl.Add(), [])
    stages = [self.compute_residual, merge]
    super(ReversibleHalfResidual, self).__init__(stages)
GitHub: google/trax — trax/layers/research/efficient_attention.py (View on GitHub)
random_rotations = self._sample_rotation(rotations_shape, vecs, rng)

    # TODO(lukaszkaiser): the dropout mask will be used for all rounds of
    # hashing, so it's shared between them. Check if that's what we want.
    dropped_vecs = self.drop_for_hash(vecs, subrng)
    rotated_vecs = np.einsum('tf,fhb->htb', dropped_vecs, random_rotations)

    if self._rehash_each_round:
      if self._factorize_hash and len(factor_list) > 1:
        # We factorized self.n_buckets as the product of factor_list.
        # Get the buckets for them and combine.
        buckets, cur_sum, cur_product = None, 0, 1
        for factor in factor_list:
          rv = rotated_vecs[..., cur_sum:cur_sum + (factor // 2)]
          cur_sum += factor // 2
          rv = np.concatenate([rv, -rv], axis=-1)
          if buckets is None:
            buckets = np.argmax(rv, axis=-1)
          else:
            buckets += cur_product * np.argmax(rv, axis=-1)
          cur_product *= factor
      else:
        rotated_vecs = np.concatenate([rotated_vecs, -rotated_vecs], axis=-1)
        buckets = np.argmax(rotated_vecs, axis=-1)
      # buckets is now (self.n_hashes, seqlen). Next we add offsets so that
      # bucket numbers from different hashing rounds don't overlap.
      offsets = jax.lax.tie_in(buckets, np.arange(self.n_hashes))
      offsets = np.reshape(offsets * self.n_buckets, (-1, 1))
      buckets = np.reshape(buckets + offsets, (-1,))
    else:
      assert not self._factorize_hash
      rotated_vecs = np.concatenate([rotated_vecs, -rotated_vecs], axis=-1)
GitHub: google/trax — trax/models/research/position_lookup_transformer.py (View on GitHub)
def forward(self, inp, weights):
    """Reshape input to have heads dimension and concatenate positions there.

    Args:
      inp: sequence whose first element `x` has shape (batch, length, d_model),
        with d_model divisible by self._n_heads.
        - If self._n_pos == 1, inp[1] is one position tensor of shape
          (batch, length, pos_depth), broadcast to every head.
        - Otherwise, inp[1:] holds one such tensor per head; concatenating
          them with the heads requires self._n_heads of them.
      weights: unused.

    Returns:
      Array of shape (batch * n_heads, length, d_head + pos_depth), where
      d_head = d_model // self._n_heads. Row b * n_heads + h holds head h of
      batch element b.
    """
    x = inp[0]
    n_batches, seqlen = x.shape[0], x.shape[1]
    d_head = x.shape[-1] // self._n_heads
    res = np.reshape(x, (n_batches, seqlen, self._n_heads, d_head))
    res = np.transpose(res, (0, 2, 1, 3))  # (batch, heads, len, depth)
    if self._n_pos == 1:  # Just one position given, tile into each head.
      pos_shape = list(res.shape)[:-1] + [inp[1].shape[-1]]
      pos = inp[1][:, None, :, :] + np.zeros(pos_shape)  # Add 0 to broadcast.
    else:  # As many positions as heads, concatenate them in.
      pos = [p[:, None, :, :] for p in inp[1:]]
      pos = np.concatenate(pos, axis=1)
    res = np.concatenate([res, pos], axis=-1)
    # n_batch, n_heads, seqlen, d_head + pos_depth
    #   -> n_batch * n_heads, seqlen, d_head + pos_depth.
    # Take the depth from the concatenated array rather than assuming the
    # position depth equals a fixed module constant. If the two disagree,
    # a hard-coded depth gives a wrong leading axis or a reshape error.
    res = np.reshape(res, (-1, seqlen, res.shape[-1]))
    return res
GitHub: google/trax — trax/layers/research/efficient_attention.py (View on GitHub)
if self._rehash_each_round:
      if self._factorize_hash and len(factor_list) > 1:
        # We factorized self.n_buckets as the product of factor_list.
        # Get the buckets for them and combine.
        buckets, cur_sum, cur_product = None, 0, 1
        for factor in factor_list:
          rv = rotated_vecs[..., cur_sum:cur_sum + (factor // 2)]
          cur_sum += factor // 2
          rv = np.concatenate([rv, -rv], axis=-1)
          if buckets is None:
            buckets = np.argmax(rv, axis=-1)
          else:
            buckets += cur_product * np.argmax(rv, axis=-1)
          cur_product *= factor
      else:
        rotated_vecs = np.concatenate([rotated_vecs, -rotated_vecs], axis=-1)
        buckets = np.argmax(rotated_vecs, axis=-1)
      # buckets is now (self.n_hashes, seqlen). Next we add offsets so that
      # bucket numbers from different hashing rounds don't overlap.
      offsets = jax.lax.tie_in(buckets, np.arange(self.n_hashes))
      offsets = np.reshape(offsets * self.n_buckets, (-1, 1))
      buckets = np.reshape(buckets + offsets, (-1,))
    else:
      assert not self._factorize_hash
      rotated_vecs = np.concatenate([rotated_vecs, -rotated_vecs], axis=-1)
      # In this configuration, we map each item to the top self.n_hashes buckets
      rotated_vecs = np.squeeze(rotated_vecs, 0)
      bucket_range = jax.lax.tie_in(vecs, np.arange(rotated_vecs.shape[-1]))
      bucket_range = np.reshape(bucket_range, (1, -1))
      bucket_range = np.broadcast_to(bucket_range, rotated_vecs.shape)

      _, buckets = jax.lax.sort_key_val(
GitHub: google/trax — trax/models/research/position_lookup_transformer.py (View on GitHub)
def PerformPositionOperations(pos, positions=None):
  """Apply five position-lookup queries to `pos`; return [q1, ..., q5].

  The key/value tables are built from the rows of `positions`. Let
  `half = positions.shape[0] // 2`, with i and j ranging over range(half):
    q1: QueryPositionKV() with no tables.
    q2: keys positions[k], values positions[k + 1].
    q3: keys positions[k + 1], values positions[k].
    q4 (binary): keys concat(positions[i], positions[j]),
        values positions[i + j].
    q5 (binary): keys concat(positions[i], positions[j]),
        values positions[max(i - j, 0)].
  For q4 the (i, j) pairs are enumerated i-major; for q5 they are j-major.

  Args:
    pos: input each query is applied to with the `@` operator.
    positions: array indexed as positions[row, :]. Despite the default, it
      must be provided; None fails on the first slice.

  Returns:
    List of the five query results, in the order above.
  """
  below, above = positions[:-1, :], positions[1:, :]
  half = int(positions.shape[0]) // 2

  def _concat_rows(pairs):
    # One key per (a, b) pair: row a followed by row b.
    return np.array([np.concatenate([positions[a, :], positions[b, :]])
                     for a, b in pairs])

  add_pairs = [(i, j) for i in range(half) for j in range(half)]
  add_keys = _concat_rows(add_pairs)
  add_values = np.array([positions[i + j, :] for i, j in add_pairs])
  # TODO(lukaszkaiser): try this below: "for j in range(i) for i in range(2*l)"
  sub_pairs = [(i, j) for j in range(half) for i in range(half)]
  sub_keys = _concat_rows(sub_pairs)
  sub_values = np.array([positions[max(i - j, 0), :] for i, j in sub_pairs])

  queries = [
      QueryPositionKV(),
      QueryPositionKV(keys=below, values=above),
      QueryPositionKV(keys=above, values=below),
      QueryPositionKV(keys=add_keys, values=add_values, binary=True),
      QueryPositionKV(keys=sub_keys, values=sub_values, binary=True)]
  return [query @ pos for query in queries]  # pylint: disable=syntax-error
GitHub: google/trax — trax/models/research/position_lookup_transformer.py (View on GitHub)
def PerformPositionOperations(pos, positions=None):
  """Apply five position-lookup queries to `pos`; return [q1, ..., q5].

  The key/value tables are built from the rows of `positions`. Let
  `half = positions.shape[0] // 2`, with i and j ranging over range(half):
    q1: QueryPositionKV() with no tables.
    q2: keys positions[k], values positions[k + 1].
    q3: keys positions[k + 1], values positions[k].
    q4 (binary): keys concat(positions[i], positions[j]),
        values positions[i + j].
    q5 (binary): keys concat(positions[i], positions[j]),
        values positions[max(i - j, 0)].
  For q4 the (i, j) pairs are enumerated i-major; for q5 they are j-major.

  Args:
    pos: input each query is applied to with the `@` operator.
    positions: array indexed as positions[row, :]. Despite the default, it
      must be provided; None fails on the first slice.

  Returns:
    List of the five query results, in the order above.
  """
  below, above = positions[:-1, :], positions[1:, :]
  half = int(positions.shape[0]) // 2

  def _concat_rows(pairs):
    # One key per (a, b) pair: row a followed by row b.
    return np.array([np.concatenate([positions[a, :], positions[b, :]])
                     for a, b in pairs])

  add_pairs = [(i, j) for i in range(half) for j in range(half)]
  add_keys = _concat_rows(add_pairs)
  add_values = np.array([positions[i + j, :] for i, j in add_pairs])
  # TODO(lukaszkaiser): try this below: "for j in range(i) for i in range(2*l)"
  sub_pairs = [(i, j) for j in range(half) for i in range(half)]
  sub_keys = _concat_rows(sub_pairs)
  sub_values = np.array([positions[max(i - j, 0), :] for i, j in sub_pairs])

  queries = [
      QueryPositionKV(),
      QueryPositionKV(keys=below, values=above),
      QueryPositionKV(keys=above, values=below),
      QueryPositionKV(keys=add_keys, values=add_values, binary=True),
      QueryPositionKV(keys=sub_keys, values=sub_values, binary=True)]
  return [query @ pos for query in queries]  # pylint: disable=syntax-error