How to use the trax.layers.combinators.Serial function in trax

To help you get started, we've selected a few trax examples based on popular ways trax.layers.combinators.Serial is used in public projects.

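Before the examples, here is a minimal sketch of the basic pattern: cb.Serial (exposed as trax.layers.Serial) runs its sublayers in order, feeding each layer's outputs to the next via the data stack. The layer sizes and input shape below are arbitrary illustration values, not part of any of the excerpts that follow.

import numpy as np
from trax import layers as tl
from trax import shapes

# A small feed-forward block built with Serial.
model = tl.Serial(
    tl.Dense(16),
    tl.Relu(),
    tl.Dense(4),
)

x = np.random.uniform(size=(2, 8)).astype(np.float32)  # arbitrary batch of inputs
model.init(shapes.signature(x))                         # create weights for this input shape
y = model(x)                                            # y.shape == (2, 4)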

github google / trax / trax / layers / rnn.py
def SRU(n_units, activation=None):
  """SRU layer as in https://arxiv.org/abs/1709.02755.

  As defined in the paper:
  (1) y_t = W x_t (+ B optionally, which we do)
  (2) f_t = sigmoid(Wf x_t + bf)
  (3) r_t = sigmoid(Wr x_t + br)
  (4) c_t = f_t * c_{t-1} + (1 - f_t) * y_t
  (5) h_t = r_t * activation(c_t) + (1 - r_t) * x_t

  We assume the input is of shape [batch, length, depth] and recurrence
  happens on the length dimension. This returns a single layer. It's best
  to use at least 2, they say in the paper, except inside a Transformer.

  Args:
    n_units: output depth of the SRU layer.
    activation: Optional activation function.

  Returns:
    The SRU layer.
  """
  # pylint: disable=no-value-for-parameter
  return cb.Serial(                                    # x
      cb.Branch(core.Dense(3 * n_units), []),          # r_f_y, x
      cb.Split(n_items=3),                             # r, f, y, x
      cb.Parallel(core.Sigmoid(), core.Sigmoid()),     # r, f, y, x
      base.Fn(lambda r, f, y: (y * (1.0 - f), f, r)),  # y * (1 - f), f, r, x
      cb.Parallel([], [], cb.Branch(MakeZeroState(), [])),
      cb.Scan(InnerSRUCell(), axis=1),
      cb.Select([0], n_in=2),                          # act(c), r, x
      activation or [],
      base.Fn(lambda c, r, x: c * r + x * (1 - r))
  )
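As a usage sketch (the shapes are illustrative assumptions; note the input depth must equal n_units so the final recombination with x broadcasts correctly):

import numpy as np
from trax import layers as tl
from trax import shapes

sru = tl.SRU(8)                                           # n_units matches the input depth
x = np.random.uniform(size=(2, 5, 8)).astype(np.float32)  # [batch, length, depth]
sru.init(shapes.signature(x))
y = sru(x)                                                # y.shape == (2, 5, 8)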
github google / trax / trax / layers / attention.py
def Attention(d_feature, n_heads=1, dropout=0.0, mode='train'):
  """Transformer-style multi-headed attention.

  Accepts inputs of the form (x, mask) and constructs (q, k, v) from x.

  Args:
    d_feature: int:  dimensionality of feature embedding
    n_heads: int: number of attention heads
    dropout: float: dropout rate
    mode: str: 'train' or 'eval'

  Returns:
    Multi-headed self-attention result and the mask.
  """
  return cb.Serial(
      cb.Dup(), cb.Dup(),
      AttentionQKV(d_feature, n_heads=n_heads, dropout=dropout, mode=mode),
  )
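To see what the two Dup calls do to the data stack, here is a small sketch with a hypothetical stand-in for AttentionQKV. FakeQKV below is not a trax layer; it only mirrors the real layer's stack signature (4 inputs in, 2 outputs out).

import numpy as np
from trax import layers as tl
from trax import shapes

# Stand-in with the same stack signature as AttentionQKV: (q, k, v, mask) -> (out, mask).
fake_qkv = tl.Fn('FakeQKV', lambda q, k, v, mask: (q + k + v, mask), n_out=2)

# (x, mask) --Dup--> (x, x, mask) --Dup--> (x, x, x, mask) --FakeQKV--> (3x, mask)
block = tl.Serial(tl.Dup(), tl.Dup(), fake_qkv)

x = np.ones((2, 3), dtype=np.float32)
mask = np.ones((2, 3), dtype=np.float32)
block.init(shapes.signature((x, mask)))
out, out_mask = block((x, mask))    # out == 3 * x, out_mask == mask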
github google / trax / trax / layers / metrics.py
def MaskedScalar(metric_layer, mask_id=None, has_weights=False):
  """Metric as scalar compatible with Trax masking."""
  # Stack of (inputs, targets) --> (metric, weight-mask).
  metric_and_mask = [
      cb.Parallel(
          [],
          cb.Dup()  # Duplicate targets
      ),
      cb.Parallel(
          metric_layer,  # Metric: (inputs, targets) --> metric
          WeightMask(mask_id=mask_id)  # pylint: disable=no-value-for-parameter
      )
  ]
  if not has_weights:
    # Take (metric, weight-mask) and return the weighted mean.
    return cb.Serial(metric_and_mask, WeightedMean())  # pylint: disable=no-value-for-parameter
  return cb.Serial(
      metric_and_mask,
      cb.Parallel(
          [],
          cb.Multiply()  # Multiply given weights by mask_id weights
      ),
      WeightedMean()  # pylint: disable=no-value-for-parameter
  )
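Note that metric_and_mask is a plain Python list; Serial accepts nested lists and runs their contents in order as sublayers. A small sketch of that behavior (the layer choices here are arbitrary):

import numpy as np
from trax import layers as tl
from trax import shapes

# A Python list passed to Serial behaves like an inline sub-sequence of layers.
preprocessing = [tl.Relu(), tl.LayerNorm()]
block = tl.Serial(preprocessing, tl.Dense(4))

x = np.random.uniform(size=(2, 8)).astype(np.float32)
block.init(shapes.signature(x))
y = block(x)    # y.shape == (2, 4)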
github google / trax / trax / layers / combinators.py
def _validate(self, layers):
    if not layers or len(layers) < 2:
      raise ValueError(
          'layers ({}) must be a list with at least two elements'.format(
              layers))
    layers = list(layers)  # Ensure we can modify layers.
    for i, obj in enumerate(layers):
      if obj is None or obj == []:  # pylint: disable=g-explicit-bool-comparison
        layers[i] = Serial(None)
      elif isinstance(obj, (list, tuple)):
        layers[i] = Serial(obj)
      else:
        if not isinstance(obj, base.Layer):
          raise ValueError(
              'Found nonlayer object ({}) in layers list: [{}].'.format(
                  obj, layers))
      if layers[i].n_in == 0:
        raise ValueError(
            'Sublayer with n_in = 0 not allowed in Parallel:'
            ' {}'.format(layers[i]))
    return layers
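The None / [] branch above is why an empty list inside Parallel behaves as a one-input pass-through. A minimal sketch (Neg is just an illustrative Fn layer, not part of trax):

import numpy as np
from trax import layers as tl
from trax import shapes

neg = tl.Fn('Neg', lambda t: -t)        # illustrative one-input layer
p = tl.Parallel(neg, [])                # [] is normalized to a pass-through Serial(None)

a = np.array([1.0, 2.0], dtype=np.float32)
b = np.array([3.0, 4.0], dtype=np.float32)
p.init(shapes.signature((a, b)))
out_a, out_b = p((a, b))                # out_a == -a, out_b == b (unchanged)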
github google / trax / trax / layers / combinators.py
def Branch(*layers):
  """Combinator that applies a list of layers in parallel to copies of inputs.

  For example, if F takes 1 input, G takes 3 inputs, and H takes 2 inputs and
  produces 2 outputs (h1, h2), then Branch(F, G, H) maps:

    - inputs: a, b, c
    - outputs: F(a), G(a, b, c), h1, h2    where h1, h2 = H(a, b)

  As an important special case, a None argument to Branch acts as if it takes
  one argument, which it leaves unchanged. (It acts as a one-arg no-op.)

  Args:
    *layers: list of layers

  Returns:
    the branch layer
  """
  parallel_layer = Parallel(*layers)
  indices = [list(range(layer.n_in)) for layer in parallel_layer.sublayers]
  return Serial(Select(_deep_flatten(indices)), parallel_layer)
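A small sketch of those stack semantics (Add is an illustrative two-input Fn layer, and [] is the one-arg no-op mentioned above):

import numpy as np
from trax import layers as tl
from trax import shapes

add = tl.Fn('Add', lambda p, q: p + q)   # consumes two inputs
branch = tl.Branch([], add)              # outputs: a (unchanged), a + b

a = np.array([1.0, 2.0], dtype=np.float32)
b = np.array([10.0, 20.0], dtype=np.float32)
branch.init(shapes.signature((a, b)))
out0, out1 = branch((a, b))              # out0 == a, out1 == a + b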
github google / trax / trax / layers / attention.py
def AttentionQKV(d_feature, n_heads=1, dropout=0.0, mode='train'):
  """Transformer-style multi-headed attention.

  Accepts inputs of the form q, k, v, mask.

  Args:
    d_feature: int:  dimensionality of feature embedding
    n_heads: int: number of attention heads
    dropout: float: dropout rate
    mode: str: 'train' or 'eval'

  Returns:
Multi-headed attention result and the mask.
  """
  return cb.Serial(
      cb.Parallel(
          core.Dense(d_feature),
          core.Dense(d_feature),
          core.Dense(d_feature),
      ),
      PureAttention(  # pylint: disable=no-value-for-parameter
          n_heads=n_heads, dropout=dropout, mode=mode),
      core.Dense(d_feature),
  )
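As a hedged usage sketch, this layer can act as self-attention by feeding the same activations for q, k and v. The (batch, 1, 1, length) mask shape below is an assumption based on trax's padding-mask convention and may differ across trax versions.

import numpy as np
from trax import layers as tl
from trax import shapes

attn_qkv = tl.AttentionQKV(d_feature=8, n_heads=2, mode='eval')

x = np.random.uniform(size=(2, 3, 8)).astype(np.float32)   # [batch, length, d_feature]
mask = np.ones((2, 1, 1, 3), dtype=np.bool_)                # assumed broadcastable mask shape

attn_qkv.init(shapes.signature((x, x, x, mask)))
out, out_mask = attn_qkv((x, x, x, mask))                   # out.shape == (2, 3, 8)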