def init(self, params):
  # Optimizer slots: two accumulators with the same shape/dtype as the weights.
  m = np.zeros_like(params)
  v = np.zeros_like(params)
  return m, v
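# The two zero slots above have the shape of an Adam-style optimizer state.
# A rough sketch of how such (m, v) slots are typically consumed; this is an
# illustration, not the library's update rule, and the name `update`, the
# `step` argument and the hyperparameter values are placeholders.
def update(step, grads, params, slots, learning_rate=1e-3):
  m, v = slots
  b1, b2, eps = 0.9, 0.999, 1e-8
  m = (1.0 - b1) * grads + b1 * m             # First-moment estimate.
  v = (1.0 - b2) * (grads ** 2) + b2 * v      # Second-moment estimate.
  mhat = m / (1.0 - b1 ** (step + 1))         # Bias correction.
  vhat = v / (1.0 - b2 ** (step + 1))
  params = params - learning_rate * mhat / (np.sqrt(vhat) + eps)
  return params, (m, v)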
# Per-slice forward pass and its VJP, nested inside a helper (omitted from
# this excerpt) that receives query_slice and ct_slice:
  def forward_slice_with_q_loop_idx(query_slice, key, value):  # pylint: disable=invalid-name
    return forward_slice(query_slice, q_loop_idx, key, value)

  output_slice, vjpfun = jax.vjp(
      forward_slice_with_q_loop_idx, query_slice, key, value)
  return output_slice, vjpfun(ct_slice)
q_loop_idx = np.zeros((), dtype=np.int32)
q_loop_max = query.shape[-2]
q_loop_stride = self._loop_stride
if q_loop_max == 1:  # For abstract runs with unknown shapes.
  q_loop_stride = 1
assert q_loop_max % q_loop_stride == 0, (
    'Stride must evenly divide the number of query elements.')

out_accum = np.zeros_like(query)
if do_backprop:
  query_ct_accum = np.zeros_like(query)
  key_ct_accum = np.zeros_like(key)
  value_ct_accum = np.zeros_like(value)
  init_vals = (
      q_loop_idx, out_accum,
      query_ct_accum, key_ct_accum, value_ct_accum)
else:
  init_vals = (q_loop_idx, out_accum)

def cond_fun(vals):  # pylint: disable=invalid-name
  q_loop_idx = vals[0]
  return jax.lax.lt(q_loop_idx, q_loop_max)

def body_fun(vals):  # pylint: disable=invalid-name
  """Compute a slice of the attention mechanism."""
  if do_backprop:
    (q_loop_idx, out_accum,
     query_ct_accum, key_ct_accum, value_ct_accum) = vals
  else:
    (q_loop_idx, out_accum) = vals
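# The (cond_fun, body_fun, init_vals) triple above is the standard input to
# jax.lax.while_loop. A minimal sketch of driving the loop, assuming body_fun
# advances q_loop_idx by q_loop_stride and returns a tuple with the same
# structure as init_vals:
final_vals = jax.lax.while_loop(cond_fun, body_fun, init_vals)
out_accum = final_vals[1]  # Accumulated attention output.
if do_backprop:
  query_ct_accum, key_ct_accum, value_ct_accum = final_vals[2:]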
# The approach here is to perform attention for one batch element and head
# at a time. Note that there is absolutely no interaction across examples or
# heads: this layer has no parameters, and hashing patterns are also
# different across examples/heads. As a result, batching doesn't give any
# performance gains except in the case of accelerator under-utilization. We
# assume that hash-based attention will be applied primarily to long
# sequences, where unbatched attention for a single head has sufficient
# computation to fill up the accelerator.
batch_loop_idx = np.zeros((), dtype=np.int32)
batch_loop_max = qk.shape[0]

init_vals = (batch_loop_idx,)
if return_output:
  out_accum = np.zeros_like(qk)
  init_vals = init_vals + (out_accum,)
if return_state:
  buckets_accum = np.zeros(
      [qk.shape[0], self.n_hashes * qk.shape[1]], dtype=np.int32)
  init_vals = init_vals + (buckets_accum,)
if ct is not None:
  qk_ct_accum = np.zeros_like(qk)
  v_ct_accum = np.zeros_like(v)
  init_vals = init_vals + (qk_ct_accum, v_ct_accum)

def cond_fun(vals):
  batch_loop_idx = vals[0]
  return jax.lax.lt(batch_loop_idx, batch_loop_max)

def body_fun(vals):
  """Performs attention for a single batch element and head."""
def forward_and_backward(self, inputs, ct, state=base.EMPTY_STATE,
                         new_state=base.EMPTY_STATE, rng=None, **kwargs):
  del kwargs
  output, _, (qk_ct, v_ct) = self.batch_call_and_or_grad(
      inputs[0], inputs[2], ct=ct, new_state=new_state, rng=rng)
  # inputs[1] is not used by this attention variant, so its cotangent is zero.
  return output, (qk_ct, np.zeros_like(inputs[1]), v_ct)
def drop_for_hash(self, x, rng):
  """Inverted dropout applied before hashing (active only in 'train' mode)."""
  rate = self._drop_for_hash_rate
  if self._mode == 'train' and rate > 0.0:
    keep = backend.random.bernoulli(rng, 1.0 - rate, x.shape)
    return np.where(keep, x / (1.0 - rate), np.zeros_like(x))
  return x
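# drop_for_hash is plain inverted dropout: survivors are rescaled by
# 1 / (1 - rate) so the expected value of the output equals the input. The
# same pattern as a self-contained function using jax.random directly (the
# function name here is illustrative, not part of the library):
import jax
import jax.numpy as jnp

def inverted_dropout(x, rate, rng):
  keep = jax.random.bernoulli(rng, 1.0 - rate, x.shape)  # Boolean keep mask.
  return jnp.where(keep, x / (1.0 - rate), jnp.zeros_like(x))

y = inverted_dropout(jnp.ones((2, 3)), 0.1, jax.random.PRNGKey(0))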
def _update_sketched(self, grads, weights, m, v, opt_params):
  """Update for higher-rank parameters."""
  learning_rate = opt_params['learning_rate']
  momentum = opt_params['momentum']
  shape = weights.shape
  rank = len(shape)
  # Broadcast each per-dimension accumulator back to the full parameter shape
  # and combine them with an elementwise minimum.
  reshaped_accumulators = [np.reshape(v[i], self._expanded_shape(shape, i))
                           for i in range(rank)]
  current_accumulator = self._minimum(reshaped_accumulators)
  current_accumulator += grads * grads
  accumulator_inv_sqrt = np.where(current_accumulator > 0.0,
                                  1.0 / np.sqrt(current_accumulator),
                                  np.zeros_like(current_accumulator))
  preconditioned_gradient = grads * accumulator_inv_sqrt
  m = (1.0 - momentum) * preconditioned_gradient + momentum * m
  weights = weights - (learning_rate * m).astype(weights.dtype)
  # Fold the full-shape accumulator back into one vector per dimension by
  # taking the maximum over all other axes.
  for i in range(len(v)):
    axes = list(range(int(i))) + list(range(int(i) + 1, rank))
    dim_accumulator = np.amax(current_accumulator, axis=axes)
    v[i] = dim_accumulator
  return weights, (m, v)
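# _update_sketched relies on two helpers that are not part of this excerpt.
# Plausible reconstructions, consistent with how they are used above (these
# are assumptions about their behavior, not necessarily the exact library code):
def _expanded_shape(self, shape, axis):
  # A shape that is 1 everywhere except `axis`, so that a length-shape[axis]
  # accumulator broadcasts against the full parameter tensor.
  return [1] * axis + [shape[axis]] + [1] * (len(shape) - axis - 1)

def _minimum(self, tensor_list):
  # Elementwise minimum across a list of mutually broadcastable tensors.
  minimum = tensor_list[0]
  for t in tensor_list[1:]:
    minimum = np.minimum(minimum, t)
  return minimum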
if mask is not None:
  # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
  # We must ensure that both mask and the -1e9 constant have a data dependency
  # on the input. Broadcasted copies of these use a lot of memory, so they
  # should be computed at runtime (rather than being global constants).
  if backend.get_name() == 'jax':
    mask = jax.lax.tie_in(dots, mask)
  # JAX's `full_like` already ties in -1e9 to dots.
  dots = np.where(mask, dots, np.full_like(dots, -1e9))

# Softmax.
dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))

if dropout is not None and dropout >= 1.0:
  raise ValueError('Dropout rates must be lower than 1.')
if dropout is not None and dropout > 0.0 and mode == 'train':
  keep = backend.random.bernoulli(rng, 1.0 - dropout, dots.shape)
  dots = np.where(keep, dots / (1.0 - dropout), np.zeros_like(dots))

out = np.matmul(dots, value)
return out
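# For reference, the masking / softmax / dropout / weighted-sum pipeline above
# as one self-contained JAX function. This is an illustrative sketch: the name
# and signature are made up, and the 1/sqrt(depth) scaling of the logits
# (which happens before the excerpt above) is included for completeness.
import jax
import jax.numpy as jnp

def dot_product_attention(query, key, value, mask=None, dropout=0.0, rng=None):
  # query, key, value: [..., length, depth]; mask broadcasts against the logits.
  depth = query.shape[-1]
  dots = jnp.matmul(query, jnp.swapaxes(key, -1, -2)) / jnp.sqrt(depth)
  if mask is not None:
    dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))
  # Stable softmax over the key dimension via logsumexp.
  dots = jnp.exp(dots - jax.scipy.special.logsumexp(dots, axis=-1, keepdims=True))
  if dropout > 0.0 and rng is not None:
    keep = jax.random.bernoulli(rng, 1.0 - dropout, dots.shape)
    dots = jnp.where(keep, dots / (1.0 - dropout), jnp.zeros_like(dots))
  return jnp.matmul(dots, value)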