How to use the trax.layers module in trax

To help you get started, we've selected a few trax.layers examples, based on popular ways the library is used in public projects.
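
As a quick orientation before the examples, here is a minimal sketch of the core trax.layers workflow (assuming a recent trax 1.x release): compose layers with combinators such as tl.Serial, initialize the result against an input signature, then call it like a function.

import numpy as np
from trax import layers as tl
from trax import shapes

# A tiny feed-forward classifier built from trax.layers combinators.
model = tl.Serial(
    tl.Dense(64),
    tl.Relu(),
    tl.Dense(10),
    tl.LogSoftmax(),
)

x = np.zeros((8, 32), dtype=np.float32)   # batch of 8 input vectors of size 32
model.init(shapes.signature(x))           # create weights for this input shape
y = model(x)                              # apply the layer stack
print(y.shape)                            # (8, 10)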


Example from google/trax: trax/models/research/reformer.py (view on GitHub)
# `np` here is trax's accelerated numpy backend; `tl` is trax.layers.
@tl.layer()
def Chunk(x, weights, n_sections=2, **kwargs):
  del weights, kwargs
  assert x.shape[1] % n_sections == 0
  return np.reshape(x, (
      x.shape[0] * n_sections,
      x.shape[1] // n_sections,
      ) + x.shape[2:])


@tl.layer()
def Unchunk(x, weights, n_sections=2, **kwargs):
  del weights, kwargs
  assert x.shape[0] % n_sections == 0
  return np.reshape(x, (
      x.shape[0] // n_sections,
      x.shape[1] * n_sections,
      ) + x.shape[2:])
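
Chunk and Unchunk are pure reshapes, so their effect is easy to check with plain NumPy outside trax; the shapes below are illustrative.

import numpy as np

n_sections = 2
x = np.arange(1 * 6 * 4).reshape((1, 6, 4))     # (batch, length, d_feature)

# Chunk: split the length axis into n_sections and fold the pieces into the batch axis.
chunked = np.reshape(x, (x.shape[0] * n_sections,
                         x.shape[1] // n_sections) + x.shape[2:])
print(chunked.shape)                            # (2, 3, 4)

# Unchunk: the exact inverse, as in the layer above.
unchunked = np.reshape(chunked, (chunked.shape[0] // n_sections,
                                 chunked.shape[1] * n_sections) + chunked.shape[2:])
assert np.array_equal(unchunked, x)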


class ReversibleHalfResidual(tl.ReversibleLayer, tl.Serial):
  """Half of a RevNet-style residual (only updates part of the hidden state)."""

  def __init__(self, residual_layers):
    self.compute_residual = tl.Serial(
        # (x1_or_y1, x2) -> (x2, x1_or_y1, x2)
        tl.Parallel([], tl.Dup()),
        tl.Swap(),
        tl.Parallel(residual_layers, [], []),
    )

    layers = [
        self.compute_residual,
        tl.Parallel(tl.Add(), [])
    ]
    super(ReversibleHalfResidual, self).__init__(layers)

    self.subtract_top = tl.Parallel(tl.SubtractTop(), [])
    self.reverse_layers = [self.compute_residual, self.subtract_top]
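
The subtract_top and reverse_layers attributes set at the end of __init__ are what make the block reversible: the forward pass maps (x1, x2) to (x1 + F(x2), x2), so x1 can be recovered exactly from the output. A plain-NumPy sketch of that idea (not the trax implementation):

import numpy as np

def forward(x1, x2, f):
  # Mirrors compute_residual followed by tl.Parallel(tl.Add(), []).
  return x1 + f(x2), x2

def reverse(y1, y2, f):
  # Mirrors compute_residual followed by tl.Parallel(tl.SubtractTop(), []).
  return y1 - f(y2), y2

f = np.tanh                       # stand-in for residual_layers
x1, x2 = np.ones((2, 3)), np.full((2, 3), 0.5)
y1, y2 = forward(x1, x2, f)
assert np.allclose(reverse(y1, y2, f)[0], x1)
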
Example from google/trax: trax/models/rnn.py (view on GitHub)
def MultiRNNCell():
    """Multi-layer RNN cell."""
    # n_layers, rnn_cell and d_model are free variables supplied by the enclosing
    # model-building function.
    assert n_layers == 2
    return tl.Serial(
        tl.Parallel([], tl.Split(n_items=n_layers)),
        tl.SerialWithSideOutputs(
            [rnn_cell(n_units=d_model) for _ in range(n_layers)]),
        tl.Parallel([], tl.Concatenate(n_items=n_layers))
    )
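
In plain Python terms, the cell splits the packed state into one piece per layer, threads the input through the per-layer cells (each layer's output feeding the next), and re-packs the new states. A conceptual sketch, with a toy cell standing in for rnn_cell:

import numpy as np

def multi_rnn_step(x, packed_state, cells):
  states = np.split(packed_state, len(cells), axis=-1)
  new_states = []
  for cell, state in zip(cells, states):
    x, state = cell(x, state)            # layer k's output feeds layer k + 1
    new_states.append(state)
  return x, np.concatenate(new_states, axis=-1)

def toy_cell(x, h):                      # stand-in for e.g. tl.GRUCell
  h = np.tanh(x + h)
  return h, h

x = np.zeros((4, 8))                     # (batch, d_model)
state = np.zeros((4, 16))                # two layers' states packed together
y, new_state = multi_rnn_step(x, state, [toy_cell, toy_cell])
print(y.shape, new_state.shape)          # (4, 8) (4, 16)
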
Example from google/trax: trax/models/transformer.py (view on GitHub)
  Args:
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    layer_idx: which layer are we at (for bookkeeping)
    mode: str: 'train' or 'eval'
    ff_activation: the non-linearity in feed-forward layer

  Returns:
    A list of layers which maps triples (decoder_activations, mask,
    encoder_activations) to triples of the same sort.
  """
  def _Dropout():
    return tl.Dropout(rate=dropout, mode=mode)

  attention_qkv = tl.AttentionQKV(
      d_model, n_heads=n_heads, dropout=dropout, mode=mode)

  causal_attention = tl.CausalAttention(
      d_model, n_heads=n_heads, mode=mode)

  feed_forward = _FeedForwardBlock(
      d_model, d_ff, dropout, layer_idx, mode, ff_activation)

  return [                             # vec_d masks vec_e
      tl.Residual(
          tl.LayerNorm(),              # vec_d ..... .....
          causal_attention,            # vec_d ..... .....
          _Dropout(),                  # vec_d ..... .....
      ),
      tl.Residual(
          tl.LayerNorm(),              # vec_d ..... .....
Example from google/trax: trax/models/resnet.py (view on GitHub)
def ConvBlock(kernel_size, filters, strides, norm, non_linearity,
              mode='train'):
  """ResNet convolutional striding block."""
  # TODO(jonni): Use good defaults so Resnet50 code is cleaner / less redundant.
  ks = kernel_size
  filters1, filters2, filters3 = filters
  main = [
      tl.Conv(filters1, (1, 1), strides),
      norm(mode=mode),
      non_linearity(),
      tl.Conv(filters2, (ks, ks), padding='SAME'),
      norm(mode=mode),
      non_linearity(),
      tl.Conv(filters3, (1, 1)),
      norm(mode=mode),
  ]
  shortcut = [
      tl.Conv(filters3, (1, 1), strides),
      norm(mode=mode),
  ]
  return [
      tl.Residual(main, shortcut=shortcut),
      non_linearity()
  ]
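
A hypothetical usage sketch: because ConvBlock returns a list of layers, it can be dropped straight into tl.Serial. Everything around it here (the stem, pooling, and head) is illustrative and not taken from the trax ResNet-50 definition.

from trax import layers as tl

def MiniResNet(n_classes, mode='train'):
  return tl.Serial(
      tl.Conv(64, (7, 7), (2, 2), 'SAME'),
      tl.BatchNorm(mode=mode),
      tl.Relu(),
      tl.MaxPool(pool_size=(3, 3), strides=(2, 2)),
      ConvBlock(3, [64, 64, 256], (1, 1), tl.BatchNorm, tl.Relu, mode=mode),
      tl.AvgPool(pool_size=(7, 7)),
      tl.Flatten(),
      tl.Dense(n_classes),
      tl.LogSoftmax(),
  )
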
Example from google/trax: trax/models/transformer.py (view on GitHub)
  Args:
    d_ff: int: depth of feed-forward layer
    dropout: float: dropout rate (how much to drop out)
    layer_idx: which layer are we at (for bookkeeping)
    mode: str: 'train' or 'eval'
    activation: the non-linearity in feed-forward layer

  Returns:
    A list of layers which maps vectors to vectors.
  """
  dropout_middle = tl.Dropout(
      rate=dropout, name='ff_middle_%d' % layer_idx, mode=mode)
  dropout_final = tl.Dropout(
      rate=dropout, name='ff_final_%d' % layer_idx, mode=mode)

  return [
      tl.LayerNorm(),
      tl.Dense(d_ff),
      activation(),
      dropout_middle,
      tl.Dense(d_model),
      dropout_final,
  ]
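
A hypothetical way to use a block like this one (the argument order follows the call in the transformer.py encoder snippet further down, and the concrete values are illustrative): wrap the returned layer list in tl.Residual, as the Transformer blocks on this page do.

from trax import layers as tl

# (d_model, d_ff, dropout, layer_idx, mode, activation)
feed_forward = _FeedForwardBlock(512, 2048, 0.1, 0, 'train', tl.Relu)
block = tl.Residual(feed_forward)    # adds the block's output back to its input
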
Example from google/trax: trax/rl/ppo.py (view on GitHub)
  @tl.layer()
  def FlattenControlsIntoTime(x, **unused_kwargs):  # pylint: disable=invalid-name
    """Splits logits for actions in different controls and flattens controls."""
    # n_actions is a free variable supplied by the enclosing function.
    return np.reshape(x, (x.shape[0], -1, n_actions))
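
This layer is again a single reshape; with illustrative shapes, a plain-NumPy check of what it does:

import numpy as np

batch, time_steps, n_controls, n_actions = 2, 5, 3, 4
logits = np.zeros((batch, time_steps, n_controls * n_actions))

# Fold the controls dimension into the time dimension, one block of action logits each.
flat = np.reshape(logits, (logits.shape[0], -1, n_actions))
print(flat.shape)    # (2, 15, 4)
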
Example from google/trax: trax/models/transformer.py (view on GitHub)
  The mask is created from the original source tokens to prevent attending to the padding
  part of the input.

  Args:
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    layer_idx: which layer are we at (for bookkeeping)
    mode: str: 'train' or 'eval'
    ff_activation: the non-linearity in feed-forward layer

  Returns:
    A list of layers that maps (activations, mask) to (activations, mask).
  """
  attention = tl.Attention(
      d_model, n_heads=n_heads, dropout=dropout, mode=mode)

  dropout_ = tl.Dropout(
      rate=dropout, name='dropout_enc_attn', mode=mode)

  feed_forward = _FeedForwardBlock(
      d_model, d_ff, dropout, layer_idx, mode, ff_activation)

  return [
      tl.Residual(
          tl.LayerNorm(),
          attention,
          dropout_,
      ),
      tl.Residual(
          feed_forward