def Chunk(x, weights, n_sections=2, **kwargs):
  """Splits the length axis into n_sections chunks, folding them into the batch."""
  del weights, kwargs
  assert x.shape[1] % n_sections == 0
  return np.reshape(x, (
      x.shape[0] * n_sections,
      x.shape[1] // n_sections,
  ) + x.shape[2:])
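# Example (illustrative): an input of shape (2, 6, 8) with n_sections=2 is
# reshaped to (4, 3, 8); Unchunk below reverses this reshape.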
# Reshape the leading batch axis into (n_devices, batch_size_per_device) so the
# batch can be split across accelerators; `n_devices` is a free variable
# captured from the enclosing function.
def f(x):
  x_shape = list(x.shape)
  batch_size = x_shape[0]
  batch_size_per_device = batch_size // n_devices
  if batch_size_per_device * n_devices != batch_size:
    raise ValueError(
        'We require that n_devices[%d] divides batch_size[%d] evenly.' %
        (n_devices, batch_size))
  new_shape_prefix = [n_devices, batch_size_per_device]
  return backend.numpy.reshape(x, new_shape_prefix + x_shape[1:])
return backend.nested_map(f, x)
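# Example (illustrative): with n_devices=2, a leaf array of shape (8, 128) is
# reshaped to (2, 4, 128); the inverse reshape appears at the end of this file.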
def Unchunk(x, weights, n_sections=2, **kwargs):
  """Inverse of Chunk: merges n_sections batch entries back along the length axis."""
  del weights, kwargs
  assert x.shape[0] % n_sections == 0
  return np.reshape(x, (
      x.shape[0] // n_sections,
      x.shape[1] * n_sections,
  ) + x.shape[2:])
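# Example (illustrative): Unchunk(Chunk(x, None, n_sections=2), None,
# n_sections=2) recovers the original shape and contents of x.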
"""Mix x = (x0, p0, ..., xH, pH) into (x0, ...., xH), p_combined.
The positions are averaged as vectors.
Args:
x: input vector, concatenated (x0, p0, ..., xH, pH).
n_heads: number of heads.
Returns:
the vector with combined xs and one with combined positions.
"""
seqlen = x.shape[1]
d_head = x.shape[2]
x = np.reshape(x, (-1, n_heads, seqlen, d_head))
x = np.transpose(x, (0, 2, 1, 3)) # -> n_batch, seqlen, n_heads, d_head
x = np.reshape(x, (-1, seqlen, n_heads * d_head))
head_size = int(d_head) - POS_VECTOR_SIZE
res, positions, idx = [], [], 0
for _ in range(n_heads):
res.append(x[:, :, idx:idx+head_size])
idx += head_size
positions.append(x[:, :, idx:idx+POS_VECTOR_SIZE])
idx += POS_VECTOR_SIZE
combined_position = sum(positions) / float(len(positions))
return np.concatenate(res, axis=-1), combined_position
def CombineHeadsPos(x, n_heads=1, **unused_kwargs):
  """Mixes x = (x0, p0, ..., xH, pH) into (x0, ..., xH), p_combined.

  The positions are averaged as vectors.

  Args:
    x: input vector, concatenated (x0, p0, ..., xH, pH).
    n_heads: number of heads.

  Returns:
    The concatenated content vectors and the averaged position vector.
  """
  seqlen = x.shape[1]
  d_head = x.shape[2]
  x = np.reshape(x, (-1, n_heads, seqlen, d_head))
  x = np.transpose(x, (0, 2, 1, 3))  # -> n_batch, seqlen, n_heads, d_head
  x = np.reshape(x, (-1, seqlen, n_heads * d_head))
  head_size = int(d_head) - POS_VECTOR_SIZE
  res, positions, idx = [], [], 0
  for _ in range(n_heads):
    res.append(x[:, :, idx:idx+head_size])
    idx += head_size
    positions.append(x[:, :, idx:idx+POS_VECTOR_SIZE])
    idx += POS_VECTOR_SIZE
  combined_position = sum(positions) / float(len(positions))
  return np.concatenate(res, axis=-1), combined_position
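# Shape summary (derived from the code above): the input has shape
# (n_batch * n_heads, seqlen, d_head), where each head's d_head dims hold
# head_size content dims followed by POS_VECTOR_SIZE position dims. The outputs
# are (n_batch, seqlen, n_heads * head_size) for the concatenated content and
# (n_batch, seqlen, POS_VECTOR_SIZE) for the averaged positions.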
axis=-1)
assert dup_counts.shape == dots.shape
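# Compensate for (query, key) pairs that are counted more than once (e.g.
# because they land in the same bucket in several hash rounds): subtracting
# log(dup_counts) divides each pair's softmax weight by its duplicate count,
# while with hard top-k selection duplicated pairs are instead masked out.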
if self._hard_k > 0:
  dots = dots - 1e7 * jax.lax.stop_gradient(dup_counts)
else:
  dots = dots - jax.lax.stop_gradient(np.log(dup_counts + 1e-9))
# Each query only attends to the top k most relevant keys.
if self._hard_k > 0:
  b_top_dots = np.sort(dots)[..., -self._hard_k:]  # Get the top k dots.
  b_top_dots = jax.lax.stop_gradient(b_top_dots)
  s_top_dots = np.reshape(b_top_dots, (-1, self._hard_k))
  top_dots = np.take(s_top_dots, undo_sort, axis=0)

  merged_top_dots = np.moveaxis(
      np.reshape(top_dots, (self.n_hashes, seqlen, self._hard_k)), 0, -1)
  merged_top_dots = np.reshape(merged_top_dots, (seqlen, -1))

  dots_thresh = np.sort(merged_top_dots)[:, -self._hard_k]
  # It's possible to compute the partition function at this point, but right
  # now this codepath isn't set up for backprop, and there might also be
  # issues computing it this way if two dot-products are exactly equal.

  sdots_thresh = dots_thresh[st]
  bdots_thresh = np.reshape(sdots_thresh, (self.n_hashes * self.n_bins, -1))
  bdots_thresh = jax.lax.stop_gradient(bdots_thresh)

  top_k_mask = jax.lax.convert_element_type(
      dots < bdots_thresh[..., None], np.float32)
  dots = dots - 1e7 * jax.lax.stop_gradient(top_k_mask)
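# Example (illustrative): with hard_k = 4, roughly the 4 largest dot-products
# per query position (merged across hash rounds) stay unchanged, while every
# logit below the threshold is pushed down by 1e7, so the softmax that follows
# assigns it (numerically) zero weight.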
def multigaussian_loss(preds, targets, ngauss=1):  # pylint: disable=invalid-name
  """Computes a mixture-of-Gaussians (diagonal-covariance) loss."""
  ndims = targets.shape[-1]
  logits = preds[:, :ngauss]
  mus = preds[:, ngauss:ngauss * (ndims + 1)]
  sigmas = preds[:, ngauss * (ndims + 1):]
  sigmas = sigmas * sigmas + 1e-6  # Make positive.
  loglogits = logits - backend.logsumexp(logits, axis=-1, keepdims=True)
  mus = np.reshape(mus, [-1, ngauss, ndims])
  sigmas = np.reshape(sigmas, [-1, ngauss, ndims])
  targets = np.reshape(targets, [-1, 1, ndims])
  glogprobs = log_gaussian_diag_pdf(targets, mus, sigmas)
  return backend.logsumexp(loglogits + glogprobs, axis=-1)
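# Layout of preds (derived from the slicing above): per example, the first
# ngauss entries are mixture logits, the next ngauss * ndims are means, and the
# final ngauss * ndims are (pre-squaring) scales, so preds.shape[-1] must equal
# ngauss * (1 + 2 * ndims). The return value is the per-example log-likelihood
# of targets under the predicted mixture.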
def _forward_train_eval(self, inputs, rng):
  (inputs, original_len, n_bins) = self._pad_inputs(inputs)
  q, k, v = inputs
  seqlen = q.shape[-2]
  # q/k/v have shape (n_batch*n_heads, seqlen, d_head).

  # Time indices for causal masking.
  t = jax.lax.tie_in(q, np.arange(seqlen))

  # Split off a "bin" axis for chunks of consecutive items.
  bq_t = np.reshape(t, (n_bins, -1))
  bq = np.reshape(q, (q.shape[0], n_bins, -1, q.shape[-1]))
  if self._share_qk:
    bk = self.make_unit_length(bq)
  else:
    bk = np.reshape(k, (k.shape[0], n_bins, -1, k.shape[-1]))
  bv = np.reshape(v, (v.shape[0], n_bins, -1, v.shape[-1]))

  # Allow each chunk to attend within itself, and also one chunk back.
  def look_one_back(x):
    # Output: pairs [ bin_i bin_{i-1} ] concatenated on the time axis.
    if len(x.shape) == 2:
      x_extra = np.concatenate([x[-1:, :], x[:-1, :]], axis=0)
      return np.concatenate([x, x_extra], axis=1)
    else:
      assert len(x.shape) == 4
      x_extra = np.concatenate([x[:, -1:, :, :], x[:, :-1, :, :]], axis=1)
      return np.concatenate([x, x_extra], axis=2)
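  # Example (illustrative): with three bins [b0, b1, b2], look_one_back yields
  # [b0|b2, b1|b0, b2|b1] along the within-bin time axis, i.e. each bin is
  # paired with the bin before it (the first bin wraps around to the last).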
  bkv_t = look_one_back(bq_t)
  bk = look_one_back(bk)
  bv = look_one_back(bv)
  # Softmax over the attention logits in `dots`.
  dots_logsumexp = backend.logsumexp(dots, axis=-1, keepdims=True)
  dots = np.exp(dots - dots_logsumexp)

  if self._dropout > 0.0:
    # Dropout is broadcast across the bin dimension.
    dropout_shape = (1, dots.shape[-2], dots.shape[-1])
    keep_prob = jax.lax.tie_in(dots, 1.0 - self._dropout)
    keep = backend.random.bernoulli(rng, keep_prob, dropout_shape)
    multiplier = keep.astype(dots.dtype) / jax.lax.tie_in(keep, keep_prob)
    dots = dots * multiplier

  bo = np.matmul(dots, bv)
  so = np.reshape(bo, (-1, bo.shape[-1]))
  slogits = np.reshape(dots_logsumexp, (-1,))
def unsort_for_output_impl(so, slogits):
  o = np.take(so, undo_sort, axis=0)
  # Sorting is considerably faster than gather, but first we need to get the
  # XLA compiler to abandon the idea of fusing this sort with the input sort
  # (which introduces a computation cycle and leads to a crash).
  # TODO(kitaev): remove "sticker_" variable if XLA is fixed.
  sticker_ = sticker + jax.lax.convert_element_type(
      slogits[0] > 0, sticker.dtype)
  _, logits = jax.lax.sort_key_val(sticker_, slogits, dimension=-1)
  return o, logits
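# Illustrative note: undo_sort is assumed to hold the inverse of the sorting
# permutation applied earlier, so np.take(so, undo_sort, axis=0) returns the
# outputs to their original sequence order; the sort_key_val call does the same
# for the log-normalizers via a sort rather than a gather.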
def unsort_for_output_vjp(so, slogits):
  """Custom gradient for unsort_for_output."""
  so = jax.lax.stop_gradient(so)
  slogits = jax.lax.stop_gradient(slogits)
# Merge the (n_devices, batch_size_per_device) leading axes back into a single
# batch axis; this is the inverse of the per-device reshape near the top of
# this file.
def f(x):
  if len(x.shape) < 2:
    return x  # No extra batch dimension: use devices as batch, so return.
  batch_size = x.shape[0] * x.shape[1]
  return backend.numpy.reshape(x, [batch_size] + list(x.shape[2:]))
return backend.nested_map(f, x_tuple)
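# Example (illustrative): a leaf array of shape (2, 4, 128) produced by the
# per-device reshape above is mapped back to shape (8, 128).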