# How to use the trax.backend.numpy.matmul function in trax

## To help you get started, we’ve selected a few trax examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

google / trax / trax / layers / research / efficient_attention.py View on Github
``````dots < bdots_thresh[..., None], np.float32)

# Softmax.
dots_logsumexp = backend.logsumexp(dots, axis=-1, keepdims=True)
dots = np.exp(dots - dots_logsumexp)

if self._dropout > 0.0:
# Dropout is broadcast across the bin dimension
dropout_shape = (1, dots.shape[-2], dots.shape[-1])
keep_prob = jax.lax.tie_in(dots, 1.0 - self._dropout)
keep = backend.random.bernoulli(rng, keep_prob, dropout_shape)
multiplier = keep.astype(dots.dtype) / jax.lax.tie_in(keep, keep_prob)
dots = dots * multiplier

bo = np.matmul(dots, bv)
so = np.reshape(bo, (-1, bo.shape[-1]))
slogits = np.reshape(dots_logsumexp, (-1,))

def unsort_for_output_impl(so, slogits):
  """Restore attention outputs and logits to their pre-sort ordering.

  Args:
    so: sorted per-position attention outputs — assumed shape
      (num_positions, d_out); TODO confirm against caller.
    slogits: sorted per-position log-sum-exp normalizers — assumed shape
      (num_positions,); TODO confirm against caller.

  Returns:
    A pair (o, logits): the same values reordered back to the original
    (pre-sort) position order via `undo_sort` / `sticker` from the
    enclosing scope.
  """
  o = np.take(so, undo_sort, axis=0)
  # Sorting is considerably faster than gather, but first we need to get the
  # XLA compiler to abandon the idea of fusing this sort with the input sort
  # (which introduces a computation cycle and leads to a crash).
  # TODO(kitaev): remove "sticker_" variable if XLA is fixed.
  # NOTE(review): the scraped snippet had the operator HTML-escaped as
  # "&gt;"; restored to ">" so the code parses. `slogits[0] > 0` only
  # creates a data dependency on the input; its value is irrelevant to the
  # sort order since it is added uniformly to every key.
  sticker_ = sticker + jax.lax.convert_element_type(
      slogits[0] > 0, sticker.dtype)
  _, logits = jax.lax.sort_key_val(sticker_, slogits, dimension=-1)
  return o, logits

def unsort_for_output_vjp(so, slogits):
google / trax / trax / layers / attention.py View on Github
``````"""Core dot product self-attention.

Args:
query: array of representations
key: array of representations
value: array of representations
dropout: float: dropout rate
mode: 'eval' or 'train': whether to use dropout
rng: JAX PRNGKey: subkey for disposable use

Returns:
Self attention for q, k, v arrays.
"""
depth = np.shape(query)[-1]
dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
# We must ensure that both mask and the -1e9 constant have a data dependency
# on the input. Broadcasted copies of these use a lot of memory, so they
# should be computed at runtime (rather than being global constants).
if backend.get_name() == 'jax':
# JAX's `full_like` already ties in -1e9 to dots.
dots = np.where(mask, dots, np.full_like(dots, -1e9))
# Softmax.
dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))
if dropout >= 1.0:
raise ValueError('Dropout rates must be lower than 1.')
if dropout is not None and dropout > 0.0 and mode == 'train':
keep = backend.random.bernoulli(rng, 1.0 - dropout, dots.shape)
dots = np.where(keep, dots / (1.0 - dropout), np.zeros_like(dots))``````
google / trax / trax / layers / attention.py View on Github
``````# TODO(kitaev): workaround for https://github.com/google/jax/issues/850
# We must ensure that both mask and the -1e9 constant have a data dependency
# on the input. Broadcasted copies of these use a lot of memory, so they
# should be computed at runtime (rather than being global constants).
if backend.get_name() == 'jax':
# JAX's `full_like` already ties in -1e9 to dots.
dots = np.where(mask, dots, np.full_like(dots, -1e9))
# Softmax.
dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))
if dropout >= 1.0:
raise ValueError('Dropout rates must be lower than 1.')
if dropout is not None and dropout > 0.0 and mode == 'train':
keep = backend.random.bernoulli(rng, 1.0 - dropout, dots.shape)
dots = np.where(keep, dots / (1.0 - dropout), np.zeros_like(dots))
out = np.matmul(dots, value)
return out``````
google / trax / trax / layers / research / efficient_attention.py View on Github
``````def forward_slice(query_slice, q_loop_idx, key, value):  # pylint: disable=invalid-name
"""Forward pass for a subset of the query vectors."""
if self._share_qk:
key = self.make_unit_length(key)

dots = np.matmul(
query_slice, np.swapaxes(key, -1, -2)) / np.sqrt(depth)

dots = dots - 1e9 * mask

# Mask out attention to self except when no other targets are available.
if self._share_qk:
dots = dots - 1e5 * self_mask

# Softmax.
dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))

if self.dropout is not None and self.dropout > 0.0:
google / trax / trax / layers / research / efficient_attention.py View on Github
``````self_mask = jax.lax.broadcasted_eye(dots.dtype, dots.shape, (2, 3))
dots = dots - 1e5 * self_mask

# Softmax.
dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))

if self.dropout > 0.0:
dropout_shape = (1, dots.shape[-3], dots.shape[-2], dots.shape[-1])
keep_prob = jax.lax.tie_in(dots, 1.0 - self.dropout)
keep = backend.random.bernoulli(rng, keep_prob, dropout_shape)
multiplier = keep.astype(dots.dtype) / jax.lax.tie_in(keep, keep_prob)
dots = dots * multiplier

bo = np.matmul(dots, bv)

output = np.reshape(bo, (bo.shape[0], -1, bo.shape[-1]))
assert output.shape == v.shape
return output[..., :original_len, :]``````
google / trax / trax / layers / research / efficient_attention.py View on Github
``````dropout_shape = (1, dots.shape[-2], dots.shape[-1])
slice_rng = jax.random.fold_in(rng, q_loop_idx)
keep_prob = jax.lax.tie_in(dots, 1.0 - self.dropout)
keep = backend.random.bernoulli(slice_rng, keep_prob, dropout_shape)
multiplier = keep.astype(dots.dtype) / jax.lax.tie_in(keep, keep_prob)
dots = dots * multiplier

if self._hard_k > 0:
top_k = np.sort(dots)[..., -self._hard_k]  # Get the top-kth weight.
dots -= top_k[..., np.newaxis]  # Subtract (be 0 for lower ones).
dots = np.maximum(dots, 0)
dots_sum = np.sum(dots, axis=-1, keepdims=True)  # Re-normalize.
dots /= dots_sum  # Re-normalize.

out_slice = np.matmul(dots, value)
return out_slice``````

Trax

Apache-2.0