How to use the trax.layers.base module in trax

To help you get started, we’ve selected a few trax examples based on popular ways trax.layers.base is used in public projects.

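The snippets follow two patterns: decorating a plain function with @base.layer() to turn it into a layer, or subclassing base.Layer and overriding its weight/state and forward methods. Below is a minimal sketch of the decorator pattern, assuming a trax version in which base.layer is available as a decorator (as in the examples that follow); DoubleAndShift is a made-up layer name used only for illustration.

from trax.backend import numpy as np   # backend-aware numpy, as in the snippets below
from trax.layers import base

# Hypothetical example layer (not part of trax): decorating a pure function with
# @base.layer() turns it into a layer-creating callable in the trax versions shown here,
# so DoubleAndShift() should construct a Layer object with no weights or state.
@base.layer()
def DoubleAndShift(x, shift=1.0, **unused_kwargs):
  return 2.0 * x + shift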

github google / trax / trax / layers / core.py
@base.layer()
def FastGelu(x, **unused_kwargs):
  return 0.5 * x * (1 + np.tanh(x * 0.7978845608 * (1 + 0.044715 * x * x)))
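The constant 0.7978845608 is sqrt(2/pi), so this is the familiar tanh approximation of GELU, 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))). As a sanity check outside trax, the approximation can be compared against the exact erf-based GELU; this sketch assumes SciPy is available for erf.

import numpy as onp
from scipy.special import erf  # assumption: SciPy is installed

def gelu_exact(x):
  # Exact GELU: x * Phi(x), where Phi is the standard normal CDF.
  return 0.5 * x * (1.0 + erf(x / onp.sqrt(2.0)))

def gelu_tanh(x):
  # Same formula as FastGelu above, in plain NumPy.
  return 0.5 * x * (1 + onp.tanh(x * 0.7978845608 * (1 + 0.044715 * x * x)))

x = onp.linspace(-5.0, 5.0, 101)
print(onp.max(onp.abs(gelu_exact(x) - gelu_tanh(x))))  # small, on the order of 1e-3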

github google / trax / trax / layers / normalization.py
"""Trax normalization layers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from trax.backend import numpy as np
from trax.layers import base


class BatchNorm(base.Layer):
  """Batch normalization."""

  def __init__(self, axis=(0, 1, 2), epsilon=1e-5, center=True, scale=True,
               momentum=0.999, mode='train'):
    super(BatchNorm, self).__init__()
    self._axis = axis
    self._epsilon = epsilon
    self._center = center
    self._scale = scale
    self._momentum = momentum
    self._mode = mode

  def new_weights_and_state(self, input_signature):
    """Helper to initialize batch norm weights."""
    axis = self._axis
    axis = (axis,) if np.isscalar(axis) else axis
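The axis handling above normalizes over every axis except the feature axis (the default axis=(0, 1, 2) matches NHWC inputs). The statistics this sets up can be sketched in plain NumPy as follows; batch_moments is a hypothetical helper for illustration, not the trax implementation.

import numpy as onp

def batch_moments(x, axis=(0, 1, 2), epsilon=1e-5):
  # Mean/variance over all axes except the last (feature) axis,
  # keeping dims so they broadcast back against x.
  axis = (axis,) if onp.isscalar(axis) else axis
  mean = onp.mean(x, axis=axis, keepdims=True)
  var = onp.mean((x - mean) ** 2, axis=axis, keepdims=True)
  return mean, var, (x - mean) / onp.sqrt(var + epsilon)

x = onp.random.randn(8, 32, 32, 3)    # NHWC batch
mean, var, normed = batch_moments(x)
print(mean.shape, var.shape)          # (1, 1, 1, 3) (1, 1, 1, 3)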

github google / trax / trax / layers / pooling.py
@base.layer()
def SumPool(x, weights, pool_size=(2, 2), strides=None, padding='VALID', **kw):
  del weights, kw
  return backend.sum_pool(x, pool_size=pool_size, strides=strides,
                          padding=padding)
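backend.sum_pool dispatches to the active backend, so its internals are not shown here. For intuition, a 2x2 sum pool with stride 2 and VALID padding can be written in plain NumPy like this (a sketch with a hypothetical sum_pool_2x2 helper, not the backend implementation).

import numpy as onp

def sum_pool_2x2(x):
  # x: (batch, height, width, channels); VALID padding, stride equal to pool size.
  n, h, w, c = x.shape
  x = x[:, :h - h % 2, :w - w % 2, :]       # drop odd remainders (VALID padding)
  x = x.reshape(n, h // 2, 2, w // 2, 2, c)
  return x.sum(axis=(2, 4))                 # sum each 2x2 window

x = onp.arange(16, dtype=onp.float32).reshape(1, 4, 4, 1)
print(sum_pool_2x2(x)[0, :, :, 0])  # [[10. 18.] [42. 50.]]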

github google / trax / trax / layers / metrics.py
@base.layer(n_in=2, n_out=1)
def CrossEntropy(x, axis=-1, **kw):
  del kw
  prediction, target = x
  return np.sum(prediction * core.one_hot(target, prediction.shape[-1]),
                axis=axis)
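Multiplying by a one-hot target and summing over the class axis simply selects the predicted (log-)probability at the target index for each example. A small NumPy illustration of that selection step, with onp.eye standing in for core.one_hot:

import numpy as onp

log_probs = onp.log(onp.array([[0.7, 0.2, 0.1],
                               [0.1, 0.1, 0.8]]))
targets = onp.array([0, 2])

one_hot = onp.eye(log_probs.shape[-1])[targets]    # stand-in for core.one_hot
selected = onp.sum(log_probs * one_hot, axis=-1)   # log p(target) per example
print(selected)             # [log 0.7, log 0.8]
print(-onp.mean(selected))  # the usual cross-entropy loss value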

github google / trax / trax / layers / research / efficient_attention.py
  def forward_with_state(self, inputs, weights=base.EMPTY_WEIGHTS,
                         state=base.EMPTY_STATE, rng=None, **kwargs):
    del weights, kwargs
    if self._mode in ('train', 'eval'):
      output = self._forward_train_eval(inputs, rng)
      return (output, state)
    else:
      assert self._mode == 'predict'
      return self._forward_predict(inputs, state, rng)

github google / trax / trax / layers / rnn.py
    memory_transform_fn: Optional transformation on the memory before gating.
    gate_nonlinearity: Function to use as gate activation. Allows trying
      alternatives to Sigmoid, such as HardSigmoid.
    candidate_nonlinearity: Nonlinearity to apply after candidate branch. Allows
      trying alternatives to traditional Tanh, such as HardTanh.
    dropout_rate_c: Amount of dropout on the transform (c) gate. Dropout works
      best in a GRU when applied exclusively to this branch.
    sigmoid_bias: Constant to add before sigmoid gates. Generally want to start
      off with a positive bias.

  Returns:
    A model representing a GRU cell with specified transforms.
  """
  gate_block = [  # u_t
      candidate_transform(),
      base.Fn(lambda x: x + sigmoid_bias),
      gate_nonlinearity(),
  ]
  reset_block = [  # r_t
      candidate_transform(),
      base.Fn(lambda x: x + sigmoid_bias),  # Want bias to start positive.
      gate_nonlinearity(),
  ]
  candidate_block = [
      cb.Dup(),
      reset_block,
      cb.Multiply(),  # Gate S{t-1} with sigmoid(candidate_transform(S{t-1}))
      candidate_transform(),  # Final projection + tanh to get Ct
      candidate_nonlinearity(),  # Candidate gate

      # Only apply dropout on the C gate. Paper reports 0.1 as a good default.
      core.Dropout(rate=dropout_rate_c)
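The role of sigmoid_bias in the gate blocks above is easy to see numerically: adding a positive constant before the sigmoid pushes the gates toward 1 at initialization, so the update gate initially favors passing the previous state through rather than the candidate. A quick illustration in plain NumPy (the helper and the bias value 0.5 are for illustration only):

import numpy as onp

def sigmoid(z):
  return 1.0 / (1.0 + onp.exp(-z))

pre_activation = 0.0                   # a freshly initialized transform outputs roughly 0
print(sigmoid(pre_activation))         # 0.5   -> gate half-open
print(sigmoid(pre_activation + 0.5))   # ~0.62 -> biased toward keeping the previous state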

github google / trax / trax / layers / attention.py
def new_weights_and_state(self, input_signature):
    d_feature = input_signature.shape[-1]
    pe = onp.zeros((self._max_len, d_feature), dtype=onp.float32)
    position = onp.arange(0, self._max_len)[:, onp.newaxis]
    div_term = onp.exp(
        onp.arange(0, d_feature, 2) * -(onp.log(10000.0) / d_feature))
    pe[:, 0::2] = onp.sin(position * div_term)
    pe[:, 1::2] = onp.cos(position * div_term)
    pe = pe[onp.newaxis, :, :]  # [1, self._max_len, d_feature]
    weights = np.array(pe)  # These are trainable parameters, initialized above.
    state = 0 if self._mode == 'predict' else base.EMPTY_STATE
    return weights, state
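The table built above has shape [1, max_len, d_feature], with sines in the even feature columns and cosines in the odd ones. It is intended to be added to embedded inputs (the forward pass is not shown in this snippet); the slicing and broadcasting involved look roughly like this standalone NumPy sketch:

import numpy as onp

max_len, d_feature = 2048, 16
pe = onp.zeros((max_len, d_feature), dtype=onp.float32)
position = onp.arange(0, max_len)[:, onp.newaxis]
div_term = onp.exp(onp.arange(0, d_feature, 2) * -(onp.log(10000.0) / d_feature))
pe[:, 0::2] = onp.sin(position * div_term)
pe[:, 1::2] = onp.cos(position * div_term)
pe = pe[onp.newaxis, :, :]                # (1, max_len, d_feature)

x = onp.random.randn(4, 100, d_feature)   # (batch, seq_len, d_feature)
y = x + pe[:, :x.shape[1], :]             # slice to seq_len, broadcast over batch
print(y.shape)                            # (4, 100, 16)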

github google / trax / trax / layers / attention.py
  def forward_with_state(self, inputs, weights=base.EMPTY_WEIGHTS,
                         state=base.EMPTY_STATE, rng=None, **kwargs):
    del weights
    q, k, v = inputs
    if self._mode in ('train', 'eval'):
      mask_size = q.shape[-2]
      # Not all backends define np.tril. However, using onp.tril is inefficient
      # in that it creates a large global constant. TODO(kitaev): try to find an
      # alternative that works across all backends.
      if backend.get_name() == 'jax':
        mask = np.tril(
            np.ones((1, mask_size, mask_size), dtype=onp.bool_), k=0)
      else:
        mask = onp.tril(
            onp.ones((1, mask_size, mask_size), dtype=onp.bool_), k=0)
    else:
      assert self._mode == 'predict'
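The lower-triangular mask built here is the standard causal (autoregressive) mask: position i may attend only to positions j <= i. For a small mask_size it looks like this in plain NumPy, matching the onp branch above:

import numpy as onp

mask_size = 4
mask = onp.tril(onp.ones((1, mask_size, mask_size), dtype=onp.bool_), k=0)
print(mask[0].astype(int))
# [[1 0 0 0]
#  [1 1 0 0]
#  [1 1 1 0]
#  [1 1 1 1]]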

github google / trax / trax / layers / normalization.py
@base.layer(new_weights_fn=_layer_norm_weights)
def LayerNorm(x, weights, epsilon=1e-6, **unused_kwargs):  # pylint: disable=invalid-name
  (scale, bias) = weights
  mean = np.mean(x, axis=-1, keepdims=True)
  variance = np.mean((x - mean)**2, axis=-1, keepdims=True)
  norm_inputs = (x - mean) / np.sqrt(variance + epsilon)
  return norm_inputs * scale + bias
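The same normalization is easy to check outside trax: after subtracting the per-example mean and dividing by the standard deviation over the feature axis, each row has approximately zero mean and unit variance before scale and bias are applied. A plain-NumPy sketch with scale initialized to ones and bias to zeros (an assumed initialization; _layer_norm_weights is not shown in this snippet):

import numpy as onp

def layer_norm(x, scale, bias, epsilon=1e-6):
  mean = onp.mean(x, axis=-1, keepdims=True)
  variance = onp.mean((x - mean) ** 2, axis=-1, keepdims=True)
  norm_inputs = (x - mean) / onp.sqrt(variance + epsilon)
  return norm_inputs * scale + bias

x = onp.random.randn(2, 8).astype(onp.float32)
scale, bias = onp.ones((8,)), onp.zeros((8,))   # assumed initialization
y = layer_norm(x, scale, bias)
print(onp.mean(y, axis=-1))  # ~0 per row
print(onp.var(y, axis=-1))   # ~1 per row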