How to use the trax.layers.Dropout function in trax

To help you get started, we’ve selected a few trax.layers.Dropout examples, based on popular ways it is used in public projects.
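Before the project examples, here is a minimal, self-contained sketch of the typical pattern. It is illustrative, not taken from the snippets below: tl.Dropout is placed between layers inside a combinator such as tl.Serial, and its mode argument controls whether it is active.

import trax.layers as tl

def TinyClassifier(d_hidden=128, n_classes=10, dropout=0.1, mode='train'):
  # Dropout zeroes activations only when mode='train'; with mode='eval'
  # the layer passes its input through unchanged.
  return tl.Serial(
      tl.Dense(d_hidden),
      tl.Relu(),
      tl.Dropout(rate=dropout, mode=mode),
      tl.Dense(n_classes),
      tl.LogSoftmax(),
  )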


github google / trax / trax / models / rnn.py
  def MultiRNNCell():
    """Multi-layer RNN cell."""
    assert n_layers == 2
    return tl.Serial(
        tl.Parallel([], tl.Split(n_items=n_layers)),
        tl.SerialWithSideOutputs(
            [rnn_cell(n_units=d_model) for _ in range(n_layers)]),
        tl.Parallel([], tl.Concatenate(n_items=n_layers))
    )

  zero_state = tl.MakeZeroState(  # pylint: disable=no-value-for-parameter
      depth_multiplier=n_layers * rnn_cell_d_state_multiplier
  )

  return tl.Serial(
      tl.ShiftRight(mode=mode),
      tl.Embedding(d_model, vocab_size),
      tl.Dropout(rate=dropout, name='embedding', mode=mode),
      tl.Branch([], zero_state),
      tl.Scan(MultiRNNCell(), axis=1),
      tl.Select([0], n_in=2),  # Drop RNN state.
      tl.Dense(vocab_size),
      tl.LogSoftmax()
  )
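Note how mode is threaded through every layer above, including tl.Dropout. A hedged usage sketch (the enclosing constructor in trax/models/rnn.py is RNNLM; the argument values here are illustrative): the model is typically built once per mode, so dropout is active during training and disabled for evaluation.

train_model = RNNLM(vocab_size=256, mode='train')  # dropout active
eval_model = RNNLM(vocab_size=256, mode='eval')    # dropout is a no-op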
github google / trax / trax / models / research / position_lookup_transformer.py
  return [
      tl.Residual(  # Self-attention block.
          tl.LayerNorm(),
          AttentionPosition(positions=positions,
                            d_model=d_model,
                            n_heads=n_heads,
                            dropout=dropout,
                            mode=mode),
          tl.Dropout(rate=dropout, mode=mode)
      ),
      tl.Residual(
          tl.LayerNorm(),
          tl.Dense(d_ff),
          tl.Relu(),
          tl.Dropout(rate=dropout, mode=mode),
          tl.Dense(d_model),
          tl.Dropout(rate=dropout, mode=mode),
      )
  ]
github google / trax / trax / models / transformer.py
def _FeedForwardBlock(d_model, d_ff, dropout, layer_idx, mode, activation):
  """Returns a list of layers implementing a feed-forward block.

  Args:
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    dropout: float: dropout rate (how much to drop out)
    layer_idx: which layer are we at (for bookkeeping)
    mode: str: 'train' or 'eval'
    activation: the non-linearity in feed-forward layer

  Returns:
    A list of layers which maps vectors to vectors.
  """
  dropout_middle = tl.Dropout(
      rate=dropout, name='ff_middle_%d' % layer_idx, mode=mode)
  dropout_final = tl.Dropout(
      rate=dropout, name='ff_final_%d' % layer_idx, mode=mode)

  return [
      tl.LayerNorm(),
      tl.Dense(d_ff),
      activation(),
      dropout_middle,
      tl.Dense(d_model),
      dropout_final,
  ]
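A quick usage note (a hedged sketch, not from the trax sources): _FeedForwardBlock returns a plain Python list of layers, and trax combinators flatten lists, so the block can be spliced directly into tl.Serial or wrapped in tl.Residual. The argument values below are illustrative.

feed_forward = _FeedForwardBlock(
    d_model=512, d_ff=2048, dropout=0.1,
    layer_idx=0, mode='train', activation=tl.Relu)
model = tl.Serial(tl.Residual(feed_forward))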
github google / trax / trax / models / transformer.py
attn_type: subclass of tl.BaseCausalAttention: attention class to use
    dropout: float: dropout rate (how much to drop out)
    share_qk: bool, whether to share queries and keys
    layer_idx: which layer are we at (for bookkeeping)
    mode: str: 'train' or 'eval'
    ff_activation: the non-linearity in feed-forward layer

  Returns:
    A list of layers that maps an activation tensor to an activation tensor.
  """
  causal_attention = tl.CausalAttention(
      d_model, n_heads=n_heads, d_attention_key=d_attn_key,
      d_attention_value=d_attn_value, attention_type=attn_type,
      share_qk=share_qk, mode=mode)

  dropout_ = tl.Dropout(
      rate=dropout, name='attention_%d' % layer_idx, mode=mode)

  feed_forward = _FeedForwardBlock(
      d_model, d_ff, dropout, layer_idx, mode, ff_activation)

  return [
      tl.Residual(
          tl.LayerNorm(),
          causal_attention,
          dropout_,
      ),
      tl.Residual(
          feed_forward
      ),
  ]
github google / trax / trax / models / research / position_lookup_transformer.py
max_len: maximal length
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  positions = _POSITIONS[:max_len, :]

  decoder_blocks = [
      _DecoderBlock(positions, d_model, d_ff, n_heads, dropout, mode)
      for _ in range(n_layers)]

  return tl.Serial(
      tl.ShiftRight(),
      tl.Embedding(d_model, vocab_size),
      tl.Dropout(rate=dropout, mode=mode),
      tl.Branch([], NewPositionalEncoding(positions=positions)),
      decoder_blocks,
      tl.Select([0], n_in=2),  # Drop positions.
      tl.LayerNorm(),
      tl.Dense(vocab_size),
      tl.LogSoftmax()
  )
github google / trax / trax / models / transformer.py
def PositionalEncoder(vocab_size):  # tokens --> vectors
    return [
        tl.Embedding(d_model, vocab_size),
        tl.Dropout(rate=dropout, mode=mode),
        tl.PositionalEncoding(max_len=max_len),
    ]
github google / trax / trax / models / transformer.py
Args:
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    layer_idx: which layer are we at (for bookkeeping)
    mode: str: 'train' or 'eval'
    ff_activation: the non-linearity in feed-forward layer

  Returns:
    A list of layers that maps (activations, mask) to (activations, mask).
  """
  attention = tl.Attention(
      d_model, n_heads=n_heads, dropout=dropout, mode=mode)

  dropout_ = tl.Dropout(
      rate=dropout, name='dropout_enc_attn', mode=mode)

  feed_forward = _FeedForwardBlock(
      d_model, d_ff, dropout, layer_idx, mode, ff_activation)

  return [
      tl.Residual(
          tl.LayerNorm(),
          attention,
          dropout_,
      ),
      tl.Residual(
          feed_forward
      ),
  ]