How to use the trax.backend.numpy.reshape function in trax

To help you get started, we’ve selected a few trax examples that show popular ways trax.backend.numpy.reshape is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github google / trax / trax / models / research / View on Github external
def Chunk(x, weights, n_sections=2, **kwargs):
  """Splits axis 1 into `n_sections` pieces and folds them into axis 0.

  Args:
    x: array with at least two axes; `x.shape[1]` must be divisible by
      `n_sections`.
    weights: unused.
    n_sections: number of pieces axis 1 is divided into.
    **kwargs: unused.

  Returns:
    `x` reshaped to `(x.shape[0] * n_sections, x.shape[1] // n_sections)`
    followed by the remaining axes, unchanged.
  """
  del weights, kwargs
  lead, mid = x.shape[0], x.shape[1]
  trailing = x.shape[2:]
  assert mid % n_sections == 0
  target_shape = (lead * n_sections, mid // n_sections) + trailing
  return np.reshape(x, target_shape)
github google / trax / trax / supervised / View on Github external
def f(x):
    """Splits the leading (batch) axis of `x` into (n_devices, per-device).

    `n_devices` and `backend` are taken from the enclosing scope.

    Args:
      x: array whose first axis is the batch axis.

    Returns:
      `x` reshaped to `[n_devices, batch_size // n_devices]` followed by the
      remaining axes.

    Raises:
      ValueError: if `n_devices` does not divide the batch size evenly.
    """
    dims = list(x.shape)
    batch_size, trailing_dims = dims[0], dims[1:]
    per_device, leftover = divmod(batch_size, n_devices)
    if leftover:
      raise ValueError(
          'We require that n_devices[%d] divides batch_size[%d] evenly.' %
          (n_devices, batch_size))
    return backend.numpy.reshape(x, [n_devices, per_device] + trailing_dims)
  return backend.nested_map(f, x)
github google / trax / trax / models / research / View on Github external
def Unchunk(x, weights, n_sections=2, **kwargs):
  """Folds `n_sections` groups of axis 0 back into axis 1.

  Args:
    x: array with at least two axes; `x.shape[0]` must be divisible by
      `n_sections`.
    weights: unused.
    n_sections: number of groups axis 0 is divided into.
    **kwargs: unused.

  Returns:
    `x` reshaped to `(x.shape[0] // n_sections, x.shape[1] * n_sections)`
    followed by the remaining axes, unchanged.
  """
  del weights, kwargs
  lead, mid = x.shape[0], x.shape[1]
  trailing = x.shape[2:]
  assert lead % n_sections == 0
  target_shape = (lead // n_sections, mid * n_sections) + trailing
  return np.reshape(x, target_shape)
github google / trax / trax / models / research / View on Github external
"""Mix x = (x0, p0, ..., xH, pH) into (x0, ...., xH), p_combined.

  The positions are averaged as vectors.

    x: input vector, concatenated (x0, p0, ..., xH, pH).
    n_heads: number of heads.

    the vector with combined xs and one with combined positions.
  seqlen = x.shape[1]
  d_head = x.shape[2]
  x = np.reshape(x, (-1, n_heads, seqlen, d_head))
  x = np.transpose(x, (0, 2, 1, 3))  # -> n_batch, seqlen, n_heads, d_head
  x = np.reshape(x, (-1, seqlen, n_heads * d_head))
  head_size = int(d_head) - POS_VECTOR_SIZE
  res, positions, idx = [], [], 0
  for _ in range(n_heads):
    res.append(x[:, :, idx:idx+head_size])
    idx += head_size
    positions.append(x[:, :, idx:idx+POS_VECTOR_SIZE])
    idx += POS_VECTOR_SIZE
  combined_position = sum(positions) / float(len(positions))
  return np.concatenate(res, axis=-1), combined_position
github google / trax / trax / models / research / View on Github external
def CombineHeadsPos(x, n_heads=1, **unused_kwargs):
  """Mix x = (x0, p0, ..., xH, pH) into (x0, ...., xH), p_combined.

  The positions are averaged as vectors.

  Args:
    x: input vector, concatenated (x0, p0, ..., xH, pH).
    n_heads: number of heads.
    **unused_kwargs: ignored.

  Returns:
    the vector with combined xs and one with combined positions.
  """
  seqlen = x.shape[1]
  d_head = x.shape[2]
  # Separate the head axis from the leading axis, then move it next to d_head
  # so all heads of one position end up in a single feature vector.
  x = np.reshape(x, (-1, n_heads, seqlen, d_head))
  x = np.transpose(x, (0, 2, 1, 3))  # -> n_batch, seqlen, n_heads, d_head
  x = np.reshape(x, (-1, seqlen, n_heads * d_head))
  # Each head contributes head_size content features followed by
  # POS_VECTOR_SIZE position features.
  head_size = int(d_head) - POS_VECTOR_SIZE
  res, positions, idx = [], [], 0
  for _ in range(n_heads):
    res.append(x[:, :, idx:idx+head_size])
    idx += head_size
    positions.append(x[:, :, idx:idx+POS_VECTOR_SIZE])
    idx += POS_VECTOR_SIZE
  # Element-wise mean of the per-head position vectors.
  combined_position = sum(positions) / float(len(positions))
  return np.concatenate(res, axis=-1), combined_position
github google / trax / trax / layers / research / View on Github external
      assert dup_counts.shape == dots.shape
      if self._hard_k > 0:
        dots = dots - 1e7 * jax.lax.stop_gradient(dup_counts)
        dots = dots - jax.lax.stop_gradient(np.log(dup_counts + 1e-9))

    # Each query only attends to the top k most relevant keys.
    if self._hard_k > 0:
      b_top_dots = np.sort(dots)[..., -self._hard_k:]  # Get the top k dots.
      b_top_dots = jax.lax.stop_gradient(b_top_dots)
      s_top_dots = np.reshape(b_top_dots, (-1, self._hard_k))
      top_dots = np.take(s_top_dots, undo_sort, axis=0)

      merged_top_dots = np.moveaxis(
          np.reshape(top_dots, (self.n_hashes, seqlen, self._hard_k)), 0, -1)
      merged_top_dots = np.reshape(merged_top_dots, (seqlen, -1))

      dots_thresh = np.sort(merged_top_dots)[:, -self._hard_k]
      # It's possible to compute the partition function at this point, but right
      # now this codepath isn't set up for backprop, and there might also be
      # issues computing it this way if two dot-products are exactly equal.

      sdots_thresh = dots_thresh[st]
      bdots_thresh = np.reshape(sdots_thresh, (self.n_hashes * self.n_bins, -1))
      bdots_thresh = jax.lax.stop_gradient(bdots_thresh)

      top_k_mask = jax.lax.convert_element_type(
          dots < bdots_thresh[..., None], np.float32)
      dots = dots - 1e7 * jax.lax.stop_gradient(top_k_mask)

    # Softmax.
github google / trax / trax / layers / View on Github external
def multigaussian_loss(preds, targets, ngauss=1):  # pylint: disable=invalid-name
  """Compute mixture of gaussians loss.

  With `ndims = targets.shape[-1]`, the second axis of `preds` is split into
  three consecutive groups: the first `ngauss` columns are mixture logits, the
  next `ngauss * ndims` columns are means, and the remaining columns are raw
  sigmas (squared and offset by 1e-6 before use).

  Args:
    preds: array indexed as `preds[:, columns]`, laid out as described above.
    targets: array whose last axis has size `ndims`.
    ngauss: number of mixture components.

  Returns:
    The logsumexp over the last axis of the normalized log mixture weights
    plus the per-component values returned by `log_gaussian_diag_pdf`.
  """
  ndims = targets.shape[-1]
  # Column index where the means end and the sigmas begin.
  sigma_start = ngauss * (ndims + 1)
  logits = preds[:, :ngauss]
  mus = preds[:, ngauss:sigma_start]
  sigmas = preds[:, sigma_start:]
  sigmas = sigmas * sigmas + 1e-6  # Make positive.
  # Normalize the logits into log mixture weights.
  loglogits = logits - backend.logsumexp(logits, axis=-1, keepdims=True)
  mus = np.reshape(mus, [-1, ngauss, ndims])
  sigmas = np.reshape(sigmas, [-1, ngauss, ndims])
  # Singleton component axis so targets broadcast against every component.
  targets = np.reshape(targets, [-1, 1, ndims])
  glogprobs = log_gaussian_diag_pdf(targets, mus, sigmas)
  return backend.logsumexp(loglogits + glogprobs, axis=-1)
github google / trax / trax / layers / research / View on Github external
def _forward_train_eval(self, inputs, rng):
    """Bins padded q/k/v and builds keys/values that also look one bin back.

    Args:
      inputs: passed to `self._pad_inputs`; the padded result unpacks into
        `(q, k, v)`.
      rng: random number generator state.
    """
    # NOTE(review): only the first part of this method is visible here;
    # `original_len` and `rng` are presumably consumed further down.
    (inputs, original_len, n_bins) = self._pad_inputs(inputs)
    q, k, v = inputs
    seqlen = q.shape[-2]
    # q/k/v are n_batch*n_heads, seqlen, d_head
    # Time indices for causal masking.
    t = jax.lax.tie_in(q, np.arange(seqlen))

    # Split off a "bin" axis for chunks of consecutive items.
    bq_t = np.reshape(t, (n_bins, -1))
    bq = np.reshape(q, (q.shape[0], n_bins, -1, q.shape[-1]))
    if self._share_qk:
      # Shared-QK: keys are derived from the binned queries.
      bk = self.make_unit_length(bq)
    else:
      bk = np.reshape(k, (k.shape[0], n_bins, -1, k.shape[-1]))
    bv = np.reshape(v, (v.shape[0], n_bins, -1, v.shape[-1]))

    # Allow each chunk to attend within itself, and also one chunk back.
    def look_one_back(x):
      # Output: pairs [ bin_i bin_{i-1} ] concatenated on the time axis.
      # The bin axis is rotated by one, so the first bin is paired with the
      # last one.
      if len(x.shape) == 2:
        x_extra = np.concatenate([x[-1:, :], x[:-1, :]], axis=0)
        return np.concatenate([x, x_extra], axis=1)
      assert len(x.shape) == 4
      x_extra = np.concatenate([x[:, -1:, :, :], x[:, :-1, :, :]], axis=1)
      return np.concatenate([x, x_extra], axis=2)

    bkv_t = look_one_back(bq_t)
    bk = look_one_back(bk)
    bv = look_one_back(bv)
github google / trax / trax / layers / research / View on Github external
# Softmax.
    dots_logsumexp = backend.logsumexp(dots, axis=-1, keepdims=True)
    dots = np.exp(dots - dots_logsumexp)

    if self._dropout > 0.0:
      # Dropout is broadcast across the bin dimension
      dropout_shape = (1, dots.shape[-2], dots.shape[-1])
      keep_prob = jax.lax.tie_in(dots, 1.0 - self._dropout)
      keep = backend.random.bernoulli(rng, keep_prob, dropout_shape)
      multiplier = keep.astype(dots.dtype) / jax.lax.tie_in(keep, keep_prob)
      dots = dots * multiplier

    bo = np.matmul(dots, bv)
    so = np.reshape(bo, (-1, bo.shape[-1]))
    slogits = np.reshape(dots_logsumexp, (-1,))

    def unsort_for_output_impl(so, slogits):
      """Reorders `so` and `slogits` for output.

      `undo_sort` and `sticker` are read from the enclosing scope.

      Args:
        so: array gathered along axis 0 using `undo_sort`.
        slogits: array reordered by a key/value sort keyed on `sticker`.

      Returns:
        A pair `(o, logits)`: the gathered `so` and the reordered `slogits`.
      """
      o = np.take(so, undo_sort, axis=0)
      # Sorting is considerably faster than gather, but first we need to get the
      # XLA compiler to abandon the idea of fusing this sort with the input sort
      # (which introduces a computation cycle and leads to a crash).
      # TODO(kitaev): remove "sticker_" variable if XLA is fixed.
      # The same 0-or-1 value is added to every key, which makes the keys
      # depend on `slogits`; presumably the relative key order is unchanged.
      sticker_ = sticker + jax.lax.convert_element_type(
          slogits[0] > 0, sticker.dtype)
      _, logits = jax.lax.sort_key_val(sticker_, slogits, dimension=-1)
      return o, logits

    def unsort_for_output_vjp(so, slogits):
      """Custom gradient for unsort_for_output."""
      so = jax.lax.stop_gradient(so)
      slogits = jax.lax.stop_gradient(slogits)
github google / trax / trax / supervised / View on Github external
def f(x):
    """Merges the two leading axes of `x` into a single batch axis.

    `backend` is taken from the enclosing scope.

    Args:
      x: array; returned untouched when it has fewer than two axes.

    Returns:
      `x` reshaped to `[x.shape[0] * x.shape[1]]` followed by the remaining
      axes, or `x` itself when it has fewer than two axes.
    """
    dims = x.shape
    if len(dims) >= 2:
      merged = dims[0] * dims[1]
      return backend.numpy.reshape(x, [merged] + list(dims[2:]))
    # No extra batch dimension: use devices as batch, so return.
    return x
  return backend.nested_map(f, x_tuple)