How to use the trax.backend module in trax

To help you get started, we’ve selected a few trax.backend examples, based on popular ways it is used in public projects.


google/trax - trax/supervised/trainer_lib.py (view on GitHub)

    # Setup state.
    rng, init_rng = jax_random.split(rng)
    self._rngs = np.stack(jax_random.split(rng, self._n_devices))
    first_shape = inputs.input_shape[0]
    # If the inputs are a tuple/list, add [None] (batch) to each element.
    if isinstance(first_shape, (list, tuple)):
      model_input_shape = tuple(
          tuple([None] + list(shape)) for shape in inputs.input_shape)
      model_target_shape = tuple(
          tuple([None] + list(shape)) for shape in inputs.target_shape)
    else:  # Otherwise just add [None] to the input shape.
      model_input_shape = tuple([None] + list(inputs.input_shape))
      model_target_shape = tuple([None] + list(inputs.target_shape))
    # Change all None to 1 in input and target shape.
    model_input_shape = backend.nested_map(lambda x: x or 1, model_input_shape)
    model_target_shape = backend.nested_map(lambda x: x or 1,
                                            model_target_shape)

    def new_opt_state_and_model_state(input_shape, input_dtype, target_shape,
                                      target_dtype, rng):
      """Returns optimizer and model states suitable for training a model."""
      # Combine inputs and targets on the stack.
      if not isinstance(input_dtype, (list, tuple)):
        input_dtype = [input_dtype]
        input_shape = [input_shape]
      if not isinstance(target_dtype, (list, tuple)):
        target_dtype = [target_dtype]
        target_shape = [target_shape]
      dtypes = list(input_dtype) + list(target_dtype)
      shapes = list(input_shape) + list(target_shape)
      if self._has_weights:
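
The backend call to note here is backend.nested_map, which applies a function to every leaf of a nested tuple/list/dict structure. A minimal standalone sketch of the same idea (assuming the usual from trax import backend import, which the snippets on this page rely on):

from trax import backend

shapes = ((None, 32), (None, 16, 16, 3))
# Replace every None (unknown batch dimension) with 1, leaf by leaf,
# while keeping the nesting structure intact.
concrete = backend.nested_map(lambda d: d or 1, shapes)
# concrete == ((1, 32), (1, 16, 16, 3))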

google/trax - trax/supervised/trainer_lib.py (view on GitHub)

def _jit_update_fn(predict_fn, loss_fn, optimizer, n_devices, jit=True):
  """Returns a (JIT-compiled) function that computes updates for one step."""
  model_and_loss = tl.Serial(predict_fn, loss_fn)
  # Gradients are always taken wrt. the first argument, so weights go first.
  def model_and_loss_call(weights, batch, state, rng):
    res = model_and_loss(batch, weights=weights, state=state, rng=rng)
    return res, model_and_loss.state
  if n_devices == 1:  # TODO(lukaszkaiser): remove branch when not needed.
    def single_update(i, opt_state, batch, state, rng):
      weights, slots, opt_params = opt_state
      rng, subrng = jax_random.split(rng[0])
      grad_fn = backend.grad(model_and_loss_call, has_aux=True)
      grads, state = grad_fn(weights, batch, state, rng)
      return optimizer.tree_update(
          i, grads, weights, slots, opt_params), state, [subrng]
    return backend.jit(single_update) if jit else single_update

  # Else, for n_devices > 1:
  @functools.partial(backend.pmap, axis_name='batch')
  def mapped_update(i, opt_state, batch, state, rng):
    """This is a multi-device version of the update function above."""
    # We assume all tensors have the first dimension = n_devices.
    weights, slots, opt_params = opt_state
    rng, subrng = jax_random.split(rng)
    grad_fn = backend.grad(model_and_loss_call, has_aux=True)
    grads, state = grad_fn(weights, batch, state, rng)
    # We do a psum(1.0) here instead of `n_devices` because `n_devices` is
    # only the number of devices on this host machine; psum, however, goes
    # over all devices on all hosts (e.g. a TPU pod), so we need to average
    # over all of them.
    grads = jax.tree_util.tree_map(
        lambda g: backend.psum(g, 'batch') / backend.psum(1.0, 'batch'), grads)
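
The backend calls doing the work here are backend.grad and backend.jit on a single device, with backend.pmap and backend.psum taking over when n_devices > 1. A minimal single-device sketch of the same grad-then-jit pattern on a toy squared-error loss (the loss and the shapes are illustrative, and the JAX backend is assumed, as the snippet itself requires):

from trax import backend

np = backend.numpy

def loss(weights, batch):
  x, y = batch
  return np.sum((np.dot(x, weights) - y) ** 2)

# As in the snippet: gradients are taken wrt. the first argument (weights),
# and the resulting function is JIT-compiled.
grad_fn = backend.jit(backend.grad(loss))

weights = np.zeros((3,))
batch = (np.ones((4, 3)), np.zeros((4,)))
grads = grad_fn(weights, batch)  # same shape as weights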

google/trax - trax/layers/combinators.py (view on GitHub)

def forward(self, inputs, weights):
    del weights
    return tuple(backend.numpy.split(inputs, self._n_items, self._axis))
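
backend.numpy mirrors the NumPy API on whichever backend is active, so the split above behaves like np.split. A minimal sketch with illustrative values:

from trax import backend

np = backend.numpy
x = np.arange(12).reshape((3, 4))
# Split into 2 equal pieces along the last axis, as the layer above does
# with self._n_items and self._axis.
parts = np.split(x, 2, axis=-1)  # two arrays of shape (3, 2)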

google/trax - trax/layers/attention.py (view on GitHub)

  dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
  if mask is not None:
    # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
    # We must ensure that both mask and the -1e9 constant have a data dependency
    # on the input. Broadcasted copies of these use a lot of memory, so they
    # should be computed at runtime (rather than being global constants).
    if backend.get_name() == 'jax':
      mask = jax.lax.tie_in(dots, mask)
    # JAX's `full_like` already ties in -1e9 to dots.
    dots = np.where(mask, dots, np.full_like(dots, -1e9))
  # Softmax.
  dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))
  if dropout is not None and dropout >= 1.0:
    raise ValueError('Dropout rates must be lower than 1.')
  if dropout is not None and dropout > 0.0 and mode == 'train':
    keep = backend.random.bernoulli(rng, 1.0 - dropout, dots.shape)
    dots = np.where(keep, dots / (1.0 - dropout), np.zeros_like(dots))
  out = np.matmul(dots, value)
  return out
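
Two backend calls do the numerical work in this snippet: backend.logsumexp for a numerically stable softmax and backend.random.bernoulli for the dropout mask. A minimal sketch of both (assuming the JAX backend, so a jax.random.PRNGKey can stand in for rng):

import jax
from trax import backend

np = backend.numpy
logits = np.array([[1.0, 2.0, 3.0]])
# Numerically stable softmax, exactly as in the snippet.
probs = np.exp(logits - backend.logsumexp(logits, axis=-1, keepdims=True))
# Dropout-style keep mask, as in the snippet's `keep` computation.
rng = jax.random.PRNGKey(0)
keep = backend.random.bernoulli(rng, 0.9, probs.shape)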

google/trax - trax/supervised/trainer_lib.py (view on GitHub)

def _combine_devices(x_tuple):
  """Combine multi-device tensors into a single batch."""
  def f(x):
    if len(x.shape) < 2:
      return x  # No extra batch dimension: use devices as batch, so return.
    batch_size = x.shape[0] * x.shape[1]
    return backend.numpy.reshape(x, [batch_size] + list(x.shape[2:]))
  return backend.nested_map(f, x_tuple)
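
The helper folds the leading device axis into the batch axis with backend.numpy.reshape and applies that to every array in the structure via backend.nested_map. A minimal sketch with illustrative shapes, following the (n_devices, per_device_batch, ...) layout the code assumes:

from trax import backend

np = backend.numpy
x = np.ones((2, 8, 16))  # (n_devices, per_device_batch, features)
y = np.ones((2, 8))      # (n_devices, per_device_batch)
merged = backend.nested_map(
    lambda t: np.reshape(t, [t.shape[0] * t.shape[1]] + list(t.shape[2:])),
    (x, y))
# merged[0].shape == (16, 16) and merged[1].shape == (16,)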

google/trax - trax/layers/combinators.py (view on GitHub)

def forward(self, xs, weights):
    del weights
    return backend.numpy.concatenate(xs, self._axis)
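
A minimal sketch of the same call, again going through backend.numpy (shapes are illustrative):

from trax import backend

np = backend.numpy
xs = (np.ones((2, 3)), np.zeros((2, 5)))
out = np.concatenate(xs, axis=-1)  # shape (2, 8)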

google/trax - trax/layers/convolution.py (view on GitHub)

def forward(self, x, weights):
    w, b = weights
    x_shape = list(x.shape)
    if len(x_shape) > 4:
      self._check_nhwc()
      new_batch_dim = six.moves.reduce(operator.mul, x_shape[:-3])
      x = np.reshape(x, [new_batch_dim] + x_shape[-3:])
    res = backend.conv(
        x, w, self._strides, self._padding, self._dimension_numbers,
        self._one) + b
    if len(x_shape) > 4:
      res = np.reshape(res, x_shape[:-3] + list(res.shape[-3:]))
    return res
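
backend.conv is called positionally above with (input, filter, strides, padding, dimension_numbers, filter dilation). A hypothetical standalone sketch, assuming NHWC inputs and HWIO filters in the dimension-number convention that JAX's conv_general_dilated uses; the actual values of self._dimension_numbers and self._one are not shown in the snippet:

from trax import backend

np = backend.numpy
x = np.ones((1, 8, 8, 3))    # NHWC input
w = np.ones((3, 3, 3, 16))   # HWIO filter
out = backend.conv(x, w, (1, 1), 'SAME', ('NHWC', 'HWIO', 'NHWC'), (1, 1))
# out.shape == (1, 8, 8, 16)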

google/trax - trax/layers/research/efficient_attention.py (view on GitHub)

def forward_and_backward(self, inputs, ct, state, new_state, **kwargs):
    assert backend.get_name() == 'jax', (
        'JAX backend is required to use forward_and_backward.')
    # Simultaneous forward pass and backprop through the attention mechanism.
    def _do_forward(x):  # pylint: disable=invalid-name
      res, _ = self.forward_with_state(x, state=state, **kwargs)
      return res
    output, vjpfun = jax.vjp(_do_forward, inputs)
    return output, vjpfun(ct)[0]
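
The fused forward-and-backward pass relies on jax.vjp, which is why the snippet asserts backend.get_name() == 'jax'. A minimal sketch of the same VJP pattern on a toy function:

import jax
import jax.numpy as jnp

def f(x):
  return jnp.tanh(x) * 2.0

inputs = jnp.ones((3,))
ct = jnp.ones((3,))          # cotangent arriving from the layers above
output, vjpfun = jax.vjp(f, inputs)
input_grads = vjpfun(ct)[0]  # gradient pulled back to the inputs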

google/trax - trax/layers/pooling.py (view on GitHub)

def MaxPool(x, weights, pool_size=(2, 2), strides=None, padding='VALID', **kw):
  del weights, kw
  return backend.max_pool(x, pool_size=pool_size, strides=strides,
                          padding=padding)
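
A minimal sketch of backend.max_pool with the keyword arguments shown above (assuming NHWC input, so pooling runs over the two spatial axes):

from trax import backend

np = backend.numpy
x = np.ones((1, 8, 8, 3))  # NHWC
y = backend.max_pool(x, pool_size=(2, 2), strides=(2, 2), padding='VALID')
# y.shape == (1, 4, 4, 3)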

google/trax - trax/layers/attention.py (view on GitHub)

def forward_with_state(self, inputs, weights=base.EMPTY_WEIGHTS,
                       state=base.EMPTY_STATE, rng=None, **kwargs):
    if self._mode in ('train', 'eval'):
      x = inputs
      symbol_size = np.shape(x)[1]
      px = weights[:, :symbol_size, :]
      if self._dropout == 0:
        return (x + px, state)
      else:
        noise_shape = list(px.shape)
        for dim in self._dropout_broadcast_dims:
          noise_shape[dim] = 1
        keep_prob = 1.0 - self._dropout
        if backend.get_name() == 'jax':
          keep_prob = jax.lax.tie_in(x, np.full((), keep_prob, dtype=x.dtype))
        keep = backend.random.bernoulli(rng, keep_prob, tuple(noise_shape))
        multiplier = keep.astype(x.dtype) / keep_prob
        return (x + px * multiplier, state)
    else:
      assert self._mode == 'predict'
      assert self._dropout == 0
      # State in this class is only used for fast inference. In that case,
      # the model is called with consecutive elements position-by-position.
      # This positional encoding layer needs to store the index of the current
      # position then and increment it on each call -- that's how state is used
      # and updated below.
      return (inputs + np.expand_dims(weights[:, state, :], 1), state + 1)
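
The dropout branch combines backend.get_name, jax.lax.tie_in and backend.random.bernoulli to build a broadcastable keep mask. A minimal sketch of that mask construction (assuming the JAX backend, so a jax.random.PRNGKey can serve as rng; the shapes are illustrative):

import jax
from trax import backend

np = backend.numpy
rng = jax.random.PRNGKey(0)
dropout = 0.1
keep_prob = 1.0 - dropout
noise_shape = (1, 1, 512)  # broadcast one mask across the leading axes
keep = backend.random.bernoulli(rng, keep_prob, noise_shape)
x = np.ones((8, 100, 512))
multiplier = keep.astype(x.dtype) / keep_prob
dropped = x * multiplier   # inverted dropout, as in the snippet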