How to use the trax.backend.numpy.sqrt function in trax

To help you get started, we’ve selected a few trax examples based on popular ways trax.backend.numpy.sqrt is used in public projects.
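All of the examples below call sqrt through trax's backend-dispatched numpy wrapper. As a quick orientation, here is a minimal sketch of calling it directly; it assumes trax is installed and that trax.backend.numpy forwards to the active backend (JAX by default), so sqrt behaves like numpy.sqrt:

from trax.backend import numpy as np

x = np.array([1.0, 4.0, 9.0])
print(np.sqrt(x))            # elementwise square root -> [1. 2. 3.]
print(1.0 / np.sqrt(512.0))  # scalar use, e.g. attention or learning-rate scaling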

github google / trax / trax / layers / initializers.py View on GitHub
"""Returns random values for initializing weights of the given `shape`."""
    fan_in, fan_out = _GetFans(shape, out_dim, in_dim)
    gain = scale
    if mode == 'fan_in':
      gain /= fan_in
    elif mode == 'fan_out':
      gain /= fan_out
    elif mode == 'fan_avg':
      gain /= (fan_in + fan_out) / 2
    if distribution == 'truncated_normal':
      # constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.)
      stddev = np.sqrt(gain) / .87962566103423978
      new_weights = random.truncated_normal(rng, -2, 2, shape) * stddev
      return new_weights.astype('float32')
    elif distribution == 'normal':
      new_weights = random.normal(rng, shape) * np.sqrt(gain)
      return new_weights.astype('float32')
    elif distribution == 'uniform':
      lim = np.sqrt(3. * gain)
      return random.uniform(rng, shape, np.float32, -lim, lim)
    else:
      raise ValueError('invalid distribution for ScaleInitializer')
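
In this initializer, np.sqrt turns a variance-like gain into either a standard deviation (normal and truncated-normal branches) or a symmetric uniform limit. A rough standalone illustration of the same arithmetic, using plain numpy and made-up fan sizes rather than the trax API:

import numpy as onp  # plain numpy, only to mirror the arithmetic above

scale, fan_in, fan_out = 1.0, 256, 512
gain = scale / ((fan_in + fan_out) / 2)         # 'fan_avg' mode
stddev = onp.sqrt(gain) / .87962566103423978    # truncated-normal std correction
lim = onp.sqrt(3. * gain)                       # uniform limit with matching variance
print(stddev, lim)
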
github google / trax / trax / learning_rate.py View on GitHub
def learning_rate(step):  # pylint: disable=invalid-name
    """Step to learning rate function."""
    ret = 1.0
    for name in factors:
      if name == 'constant':
        ret *= constant
      elif name == 'linear_warmup':
        ret *= np.minimum(1.0, step / warmup_steps)
      elif name == 'rsqrt_decay':
        ret /= np.sqrt(np.maximum(step, warmup_steps))
      elif name == 'rsqrt_normalized_decay':
        ret *= np.sqrt(warmup_steps)
        ret /= np.sqrt(np.maximum(step, warmup_steps))
      elif name == 'decay_every':
        ret *= (decay_factor ** (step//steps_per_decay))
      elif name == 'cosine_decay':
        progress = np.maximum(
            0.0, (step - warmup_steps) / float(steps_per_cycle))
        ret *= np.maximum(0.0, 0.5 * (1.0 + np.cos(np.pi * (progress % 1.0))))
      else:
        raise ValueError('Unknown factor %s.' % name)
    ret = np.asarray(ret, dtype=np.float32)
    return {'learning_rate': ret}
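
Here the sqrt calls implement the inverse-square-root ('rsqrt') decay factors: after warmup, the rate shrinks in proportion to 1/sqrt(step). A back-of-the-envelope sketch of evaluating the common 'constant * linear_warmup * rsqrt_decay' combination at a single step, with made-up hyperparameters and plain numpy:

import numpy as onp

constant, warmup_steps, step = 2.0, 1000, 250
lr = constant
lr *= min(1.0, step / warmup_steps)        # linear_warmup
lr /= onp.sqrt(max(step, warmup_steps))    # rsqrt_decay
print(lr)                                  # ~0.0158
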
github google / trax / trax / layers / research / efficient_attention.py View on GitHub
# same bucket, so this increases the chances of attending to relevant items.
    # TODO(kitaev): benchmark whether XLA pad operation is noticeably faster.
    def look_one_back(x):
      if len(x.shape) == 2:
        x_extra = np.concatenate([x[-1:, :], x[:-1, :]], axis=0)
      else:
        x_extra = np.concatenate([x[-1:, :, :], x[:-1, :, :]], axis=0)
      return np.concatenate([x, x_extra], axis=1)

    bk = look_one_back(bk)
    bv = look_one_back(bv)
    bkv_t = look_one_back(bkv_t)
    bkv_buckets = look_one_back(bkv_buckets)

    # Dot-product attention.
    dots = np.matmul(bq, np.swapaxes(bk, -1, -2)) / np.sqrt(bq.shape[-1])

    # Causal masking
    mask = jax.lax.convert_element_type(
        jax.lax.lt(bq_t[:, :, None], bkv_t[:, None, :]),
        np.float32)
    dots = dots - 1e9 * mask

    # Mask out attention to self except when no other targets are available.
    self_mask = jax.lax.convert_element_type(
        jax.lax.eq(bq_t[:, :, None], bkv_t[:, None, :]),
        np.float32)
    dots = dots - 1e5 * self_mask

    # Mask out attention to other hash buckets.
    if not self._attend_across_buckets:
      bucket_mask = jax.lax.convert_element_type(
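
The np.sqrt call in this excerpt is the standard scaled dot-product attention normalization: query-key dot products are divided by the square root of the head depth before masking and softmax, keeping the logits at a manageable scale. A small self-contained sketch of just that scaling step, in plain numpy rather than the JAX-backed trax code above:

import numpy as onp

bq = onp.random.randn(8, 64)    # (queries, depth)
bk = onp.random.randn(16, 64)   # (keys, depth)
dots = onp.matmul(bq, onp.swapaxes(bk, -1, -2)) / onp.sqrt(bq.shape[-1])
print(dots.shape)               # (8, 16)
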
github google / trax / trax / optimizers / base.py View on GitHub
def l2_norm(tree):
  """Compute the l2 norm of a pytree of arrays. Useful for weight decay."""
  leaves = tree_flatten(tree)
  return np.sqrt(sum(np.vdot(x, x) for x in leaves))
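
l2_norm flattens a nested structure (pytree) of arrays and takes the square root of the sum of squared entries, i.e. the global L2 norm, which is useful for weight decay or gradient clipping. A quick numpy-only illustration of the same computation on a plain list of arrays:

import numpy as onp

leaves = [onp.ones((2, 3)), onp.full((4,), 2.0)]
norm = onp.sqrt(sum(onp.vdot(x, x) for x in leaves))
print(norm)   # sqrt(6 * 1.0 + 4 * 4.0) = sqrt(22) ~ 4.69
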
github google / trax / trax / optimizers / sm3.py View on GitHub
def _update_diagonal(self, grads, weights, m, v, opt_params):
    learning_rate = opt_params['learning_rate']
    momentum = opt_params['momentum']
    v[0] += grads * grads
    preconditioner = np.where(v[0] > 0, 1.0 / np.sqrt(v[0]),
                              np.zeros_like(v[0]))
    preconditioned_grads = preconditioner * grads
    m = (1 - momentum) * preconditioned_grads + momentum * m
    weights = weights - (learning_rate * m).astype(weights.dtype)
    return weights, (m, v)
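
In this SM3 update, sqrt appears in an Adagrad-style preconditioner: each gradient entry is scaled by one over the square root of its accumulated squared gradients, and entries whose accumulator is still zero are left untouched. A minimal numpy sketch of just the preconditioning step, with made-up values:

import numpy as onp

grads = onp.array([0.5, -1.0, 0.0])
v = onp.zeros_like(grads)
v += grads * grads                           # accumulate squared gradients
with onp.errstate(divide='ignore'):          # np.where evaluates both branches
    precond = onp.where(v > 0, 1.0 / onp.sqrt(v), onp.zeros_like(v))
print(precond * grads)                       # [ 1. -1.  0.]
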
github google / trax / trax / layers / initializers.py View on Github external
if mode == 'fan_in':
      gain /= fan_in
    elif mode == 'fan_out':
      gain /= fan_out
    elif mode == 'fan_avg':
      gain /= (fan_in + fan_out) / 2
    if distribution == 'truncated_normal':
      # constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.)
      stddev = np.sqrt(gain) / .87962566103423978
      new_weights = random.truncated_normal(rng, -2, 2, shape) * stddev
      return new_weights.astype('float32')
    elif distribution == 'normal':
      new_weights = random.normal(rng, shape) * np.sqrt(gain)
      return new_weights.astype('float32')
    elif distribution == 'uniform':
      lim = np.sqrt(3. * gain)
      return random.uniform(rng, shape, np.float32, -lim, lim)
    else:
      raise ValueError('invalid distribution for ScaleInitializer')