How to use the trax.layers.Serial function in trax

To help you get started, we've selected a few examples that show how trax.layers.Serial is used in public projects.
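
tl.Serial is trax's core combinator: it chains layers so that each layer consumes the outputs of the previous one, using trax's data-stack semantics. Before the larger examples below, here is a minimal, self-contained sketch (layer sizes are arbitrary and chosen only for illustration; it assumes a recent trax release, where a layer is initialized from a signature and then called directly):

import numpy as np
from trax import layers as tl
from trax import shapes

# A small feed-forward stack: Dense -> Relu -> Dense -> LogSoftmax.
model = tl.Serial(
    tl.Dense(32),
    tl.Relu(),
    tl.Dense(8),
    tl.LogSoftmax(),
)

# Initialize weights from an input signature, then apply the layer.
x = np.zeros((4, 16), dtype=np.float32)  # batch of 4 feature vectors of size 16
model.init(shapes.signature(x))
y = model(x)                             # shape (4, 8)

Passing a list of layers to tl.Serial, or nesting Serial layers inside one another, works the same way; that is the pattern most of the examples below rely on.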

github google/trax: trax/models/research/position_lookup_transformer.py
def AttentionPosition(vec, pos,
                      positions=None, d_model=None, n_heads=8,
                      dropout=0.0, mode='train'):
  """Transformer-style multi-headed attention."""

  new_posns = list(LearnedPosOperations(positions=positions,
                                        n_combinations=n_heads) @ (vec, pos))

  hq = tl.Serial(tl.Dense(d_model),
                 CopyPosToHeads(n_heads, tile=False)) @ ([vec,] + new_posns)
  hk = tl.Serial(tl.Dense(d_model),
                 CopyPosToHeads(n_heads, tile=True)) @ (vec, pos)
  hv = tl.ComputeAttentionHeads(
      n_heads=n_heads, d_head=d_model // n_heads) @ vec

  x, pos = tl.Serial(
      tl.DotProductCausalAttention(dropout=dropout, mode=mode),
      CombineHeadsPos(n_heads=n_heads),
      tl.Dense(d_model)) @ (hq, hk, hv)

  return x, pos
github google/trax: trax/models/rnn.py
def MultiRNNCell():
    """Multi-layer RNN cell."""
    assert n_layers == 2
    return tl.Serial(
        tl.Parallel([], tl.Split(n_items=n_layers)),
        tl.SerialWithSideOutputs(
            [rnn_cell(n_units=d_model) for _ in range(n_layers)]),
        tl.Parallel([], tl.Concatenate(n_items=n_layers))
    )
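
The MultiRNNCell example above leans on Serial's stack semantics: tl.Parallel, tl.Split and tl.Concatenate route several stack items through different sub-layers inside one Serial pipeline. A minimal sketch of that routing idea on plain arrays (shapes and sizes are made up, for illustration only):

import numpy as np
from trax import layers as tl
from trax import shapes

# Two stack inputs: the first goes through a Dense, the second passes through
# unchanged ([] acts as the identity), then both are concatenated.
block = tl.Serial(
    tl.Parallel(tl.Dense(8), []),
    tl.Concatenate(n_items=2),
)

x1 = np.zeros((2, 4), dtype=np.float32)
x2 = np.zeros((2, 3), dtype=np.float32)
block.init(shapes.signature((x1, x2)))
y = block((x1, x2))   # shape (2, 11): 8 from Dense(8) plus 3 passed through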
github google/trax: trax/supervised/trainer_lib.py
      if not isinstance(input_dtype, (list, tuple)):
        input_dtype = [input_dtype]
        input_shape = [input_shape]
      if not isinstance(target_dtype, (list, tuple)):
        target_dtype = [target_dtype]
        target_shape = [target_shape]
      dtypes = list(input_dtype) + list(target_dtype)
      shapes = list(input_shape) + list(target_shape)
      if self._has_weights:
        shapes += list(target_shape)
        dtypes += [np.float32 for _ in target_dtype]
      input_signature = tuple(ShapeDtype(s, d)
                              for (s, d) in zip(shapes, dtypes))
      # We need to create a new model instance and not reuse `model_train` here,
      # because `m.initialize` puts cached parameter values in `m` and hence the
      # next call of `m.initialize` will give wrong results.
      m = tl.Serial(model(mode='train'), loss_fn)
      m._set_rng_recursive(rng)  # pylint: disable=protected-access
      weights, state = m.init(input_signature)
      (slots, opt_params) = opt.tree_init(weights)
      return (OptState(weights, slots, opt_params), state)
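
The trainer_lib snippet above shows a common pattern: wrapping a model and a loss layer in one tl.Serial so the combined layer can be initialized from a ShapeDtype signature. A hedged sketch of the same pattern in isolation (the toy model, shapes and dtypes here are made up for illustration):

import numpy as np
from trax import layers as tl
from trax.shapes import ShapeDtype

# A toy model followed by a loss; together they consume (inputs, targets, weights).
model = tl.Serial(tl.Dense(16), tl.Relu(), tl.Dense(4), tl.LogSoftmax())
model_with_loss = tl.Serial(model, tl.CrossEntropyLoss())

input_signature = (
    ShapeDtype((8, 32), np.float32),  # batch of inputs
    ShapeDtype((8,), np.int32),       # integer targets
    ShapeDtype((8,), np.float32),     # per-example loss weights
)
weights, state = model_with_loss.init(input_signature)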
github google/trax: trax/models/transformer.py
  out_encoder = (in_encoder if output_vocab_size is None
                 else PositionalEncoder(output_vocab_size))
  if output_vocab_size is None:
    output_vocab_size = input_vocab_size

  encoder_blocks = [
      _EncoderBlock(
          d_model, d_ff, n_heads, dropout, i, mode, ff_activation)
      for i in range(n_encoder_layers)]

  encoder_decoder_blocks = [
      _EncoderDecoderBlock(
          d_model, d_ff, n_heads, dropout, i, mode, ff_activation)
      for i in range(n_decoder_layers)]

  # Assemble and return the model.
  return tl.Serial(
      # Input: encoder_side_tokens, decoder_side_tokens
      # Copy decoder tokens for use in loss.
      tl.Select([0, 1, 1]),               # tok_e tok_d tok_d

      # Encode.
      tl.Branch(
          in_encoder, tl.PaddingMask()),  # vec_e masks ..... .....
      encoder_blocks,                     # vec_e masks ..... .....
      tl.LayerNorm(),                     # vec_e ..... ..... .....

      # Decode.
      tl.Select([2, 1, 0]),               # tok_d masks vec_e .....
      tl.ShiftRight(),                    # tok_d ..... ..... .....
      out_encoder,                        # vec_d ..... ..... .....
      tl.Branch(
          [], tl.EncoderDecoderMask()),   # vec_d masks ..... .....
      encoder_decoder_blocks,             # vec_d masks vec_e .....
      tl.LayerNorm(),                     # vec_d masks vec_e .....

      # Map to output vocab.
      tl.Select([0], n_in=3),             # vec_d tok_d
      tl.Dense(output_vocab_size),        # vec_d .....
      tl.LogSoftmax(),                    # vec_d .....
  )
github google/trax: trax/models/research/reformer.py
def __init__(self, layer, n_sections=1, check_shapes=True):
    """Initialize the combinator.

    Args:
      layer: a layer to apply to each element.
      n_sections: how many sections to map to (default: 1).
      check_shapes: whether to check that shapes are identical (default: true).

    Returns:
      A new layer representing mapping layer to all elements of the input.
    """
    super(Map, self).__init__(n_in=n_sections, n_out=n_sections)
    if layer is None or isinstance(layer, (list, tuple)):
      layer = tl.Serial(layer)
    self._layer = layer
    # Generally a Map should be applied to lists where all elements have
    # the same shape -- because self._layer will only be initialized once
    # and it could have different parameters for different shapes. But there
    # are valid cases -- e.g., when self._layer has no parameters -- where we
    # can apply Map to different shapes -- set check_shapes=False in such cases.
    self._check_shapes = check_shapes
    self._n_sections = n_sections
github google/trax: trax/rl/ppo.py
  if two_towers:
    layers = [
        tl.Dup(),
        tl.Parallel(
            [bottom_layers_fn(**kwargs),
             tl.Dense(n_preds_per_input * n_actions),
             FlattenControlsIntoTime(),  # pylint: disable=no-value-for-parameter
             tl.LogSoftmax()],
            [bottom_layers_fn(**kwargs),
             tl.Dense(n_preds_per_input),
             tl.Flatten()],
        )
    ]
  else:
    layers = [
        bottom_layers_fn(**kwargs),
        tl.Dup(),
        tl.Parallel(
            [tl.Dense(n_preds_per_input * n_actions),
             FlattenControlsIntoTime(),  # pylint: disable=no-value-for-parameter
             tl.LogSoftmax()],
            [tl.Dense(n_preds_per_input), tl.Flatten()],
        )
    ]
  return tl.Serial(layers)
github google/trax: trax/models/atari_cnn.py
def AtariCnn(n_frames=4, hidden_sizes=(32, 32), output_size=128, mode='train'):
  """An Atari CNN."""
  del mode

  # TODO(jonni): Include link to paper?
  # Input shape: (B, T, H, W, C)
  # Output shape: (B, T, output_size)
  return tl.Serial(
      tl.Fn(lambda x: x / 255.0),  # Convert unsigned bytes to float.
      _FrameStack(n_frames=n_frames),  # (B, T, H, W, 4C)

      tl.Conv(hidden_sizes[0], (5, 5), (2, 2), 'SAME'),
      tl.Relu(),
      tl.Conv(hidden_sizes[1], (5, 5), (2, 2), 'SAME'),
      tl.Relu(),
      tl.Flatten(n_axes_to_keep=2),  # B, T and rest.
      tl.Dense(output_size),
      tl.Relu(),
  )
github google/trax: trax/models/research/reformer.py
def __init__(self, pre_attention, attention, post_attention):
    self.pre_attention = tl.Serial(
        # (x1_or_y1, x2) -> (x2, x1_or_y1, x2)
        tl.Parallel([], tl.Dup()),
        tl.Swap(),
        tl.Parallel(pre_attention, [], []),
    )
    assert hasattr(attention, 'forward_and_backward')
    self.attention = ApplyAttentionWrapper(attention)
    self.post_attention = tl.Parallel(post_attention, [], [])

    layers = [
        self.pre_attention,
        self.attention,
        self.post_attention,
        tl.Parallel(tl.Add(), []),
    ]
    super(ReversibleAttentionHalfResidual, self).__init__(layers)