Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
while factor > 0 and not (
self.n_buckets % factor == 0 and
factor % 2 == 0 and
(self.n_buckets // factor) % 2 == 0):
factor -= 1
if factor > 2: # Factor of 2 does not warrant the effort.
rot_size = factor + (self.n_buckets // factor)
factor_list = [factor, self.n_buckets // factor]
rotations_shape = (
vecs.shape[-1],
self.n_hashes if self._rehash_each_round else 1,
rot_size // 2)
rng = jax.lax.tie_in(vecs, rng)
rng, subrng = backend.random.split(rng)
random_rotations = self._sample_rotation(rotations_shape, vecs, rng)
# TODO(lukaszkaiser): the dropout mask will be used for all rounds of
# hashing, so it's shared between them. Check if that's what we want.
dropped_vecs = self.drop_for_hash(vecs, subrng)
rotated_vecs = np.einsum('tf,fhb->htb', dropped_vecs, random_rotations)
if self._rehash_each_round:
if self._factorize_hash and len(factor_list) > 1:
# We factorized self.n_buckets as the product of factor_list.
# Get the buckets for them and combine.
buckets, cur_sum, cur_product = None, 0, 1
for factor in factor_list:
rv = rotated_vecs[..., cur_sum:cur_sum + (factor // 2)]
cur_sum += factor // 2
rv = np.concatenate([rv, -rv], axis=-1)
def Chunk(x, weights, n_sections=2, **kwargs):
    """Reshape `x` so that axis 0 grows and axis 1 shrinks by `n_sections`.

    Args:
      x: array with at least two axes; `x.shape[1]` must be a multiple of
        `n_sections`.
      weights: unused.
      n_sections: factor moved from axis 1 to axis 0.
      **kwargs: unused.

    Returns:
      `x` reshaped to
      `(x.shape[0] * n_sections, x.shape[1] // n_sections) + x.shape[2:]`.
    """
    del weights, kwargs  # Accepted for the layer call signature only.
    leading, second = x.shape[0], x.shape[1]
    assert second % n_sections == 0
    chunked_shape = (leading * n_sections, second // n_sections) + x.shape[2:]
    return np.reshape(x, chunked_shape)
def f(x):
    """Split the leading axis of `x` into `(n_devices, batch_per_device)`.

    `n_devices` and `backend` are taken from the enclosing scope.

    Raises:
      ValueError: if `n_devices` does not divide the leading dimension.
    """
    dims = list(x.shape)
    total_batch = dims[0]
    per_device = total_batch // n_devices
    # Multiplying back detects any remainder lost by the floor division.
    if per_device * n_devices != total_batch:
        raise ValueError(
            'We require that n_devices[%d] divides batch_size[%d] evenly.' %
            (n_devices, total_batch))
    return backend.numpy.reshape(x, [n_devices, per_device] + dims[1:])
return backend.nested_map(f, x)
def Unchunk(x, weights, n_sections=2, **kwargs):
    """Reshape `x` so that axis 0 shrinks and axis 1 grows by `n_sections`.

    Args:
      x: array with at least two axes; `x.shape[0]` must be a multiple of
        `n_sections`.
      weights: unused.
      n_sections: factor moved from axis 0 to axis 1.
      **kwargs: unused.

    Returns:
      `x` reshaped to
      `(x.shape[0] // n_sections, x.shape[1] * n_sections) + x.shape[2:]`.
    """
    del weights, kwargs  # Accepted for the layer call signature only.
    stacked = x.shape[0]
    assert stacked % n_sections == 0
    merged_shape = (stacked // n_sections, x.shape[1] * n_sections) + x.shape[2:]
    return np.reshape(x, merged_shape)
x.shape[0] * n_sections,
x.shape[1] // n_sections,
) + x.shape[2:])
@tl.layer()
def Unchunk(x, weights, n_sections=2, **kwargs):
    """Reshape `x` so that axis 0 shrinks and axis 1 grows by `n_sections`.

    Args:
      x: array with at least two axes; `x.shape[0]` must be a multiple of
        `n_sections`.
      weights: unused.
      n_sections: factor moved from axis 0 to axis 1.
      **kwargs: unused.

    Returns:
      `x` reshaped to
      `(x.shape[0] // n_sections, x.shape[1] * n_sections) + x.shape[2:]`.
    """
    del weights, kwargs  # Present only to satisfy the layer signature.
    n_rows = x.shape[0]
    assert n_rows % n_sections == 0
    target_shape = (n_rows // n_sections, x.shape[1] * n_sections)
    target_shape = target_shape + x.shape[2:]
    return np.reshape(x, target_shape)
class ReversibleHalfResidual(tl.ReversibleLayer, tl.Serial):
    """Half of a RevNet-style residual (only updates part of the hidden state)."""

    def __init__(self, residual_layers):
        """Assemble the residual computation followed by an addition.

        Args:
          residual_layers: layer(s) placed in the first slot of the parallel
            combinator that follows the duplicate-and-swap step.
        """
        # Dup then Swap: (x1_or_y1, x2) -> (x2, x1_or_y1, x2)
        duplicate_second = tl.Parallel([], tl.Dup())
        swap_top = tl.Swap()
        residual_on_top = tl.Parallel(residual_layers, [], [])
        # Stored as an attribute before the parent constructor runs.
        self.compute_residual = tl.Serial(
            duplicate_second, swap_top, residual_on_top)

        add_into_top = tl.Parallel(tl.Add(), [])
        super(ReversibleHalfResidual, self).__init__(
            [self.compute_residual, add_into_top])
random_rotations = self._sample_rotation(rotations_shape, vecs, rng)
# TODO(lukaszkaiser): the dropout mask will be used for all rounds of
# hashing, so it's shared between them. Check if that's what we want.
dropped_vecs = self.drop_for_hash(vecs, subrng)
rotated_vecs = np.einsum('tf,fhb->htb', dropped_vecs, random_rotations)
if self._rehash_each_round:
if self._factorize_hash and len(factor_list) > 1:
# We factorized self.n_buckets as the product of factor_list.
# Get the buckets for them and combine.
buckets, cur_sum, cur_product = None, 0, 1
for factor in factor_list:
rv = rotated_vecs[..., cur_sum:cur_sum + (factor // 2)]
cur_sum += factor // 2
rv = np.concatenate([rv, -rv], axis=-1)
if buckets is None:
buckets = np.argmax(rv, axis=-1)
else:
buckets += cur_product * np.argmax(rv, axis=-1)
cur_product *= factor
else:
rotated_vecs = np.concatenate([rotated_vecs, -rotated_vecs], axis=-1)
buckets = np.argmax(rotated_vecs, axis=-1)
# buckets is now (self.n_hashes, seqlen). Next we add offsets so that
# bucket numbers from different hashing rounds don't overlap.
offsets = jax.lax.tie_in(buckets, np.arange(self.n_hashes))
offsets = np.reshape(offsets * self.n_buckets, (-1, 1))
buckets = np.reshape(buckets + offsets, (-1,))
else:
assert not self._factorize_hash
rotated_vecs = np.concatenate([rotated_vecs, -rotated_vecs], axis=-1)
def forward(self, inp, weights):
    """Reshape input to have heads dimension and concatenate positions there.

    Args:
      inp: sequence whose first element is the activations `x`, reshaped
        here from (batch, seqlen, n_heads * d_head). It is followed by a
        single (batch, seqlen, pos_depth) position array when
        `self._n_pos == 1`, or otherwise by the position arrays that are
        stacked along the heads axis.
      weights: unused.

    Returns:
      Array of shape (batch * n_heads, seqlen, d_head + pos_depth).
    """
    del weights  # Unused.
    x = inp[0]
    n_batches, seqlen = x.shape[0], x.shape[1]
    d_head = x.shape[-1] // self._n_heads
    res = np.reshape(x, (n_batches, seqlen, self._n_heads, d_head))
    res = np.transpose(res, (0, 2, 1, 3))  # (batch, heads, len, depth)
    if self._n_pos == 1:  # Just one position given, tile into each head.
        pos_shape = list(res.shape)[:-1] + [inp[1].shape[-1]]
        pos = inp[1][:, None, :, :] + np.zeros(pos_shape)  # Add 0 to broadcast.
    else:  # As many positions as heads, concatenate them in.
        pos = np.concatenate([p[:, None, :, :] for p in inp[1:]], axis=1)
    res = np.concatenate([res, pos], axis=-1)
    # n_batch, n_heads, seqlen, depth -> n_batch*n_heads, seqlen, depth
    # Take the depth from the concatenated array itself, so the result stays
    # correctly aligned whatever the width of the position vectors.
    return np.reshape(res, (-1, seqlen, res.shape[-1]))
if self._rehash_each_round:
if self._factorize_hash and len(factor_list) > 1:
# We factorized self.n_buckets as the product of factor_list.
# Get the buckets for them and combine.
buckets, cur_sum, cur_product = None, 0, 1
for factor in factor_list:
rv = rotated_vecs[..., cur_sum:cur_sum + (factor // 2)]
cur_sum += factor // 2
rv = np.concatenate([rv, -rv], axis=-1)
if buckets is None:
buckets = np.argmax(rv, axis=-1)
else:
buckets += cur_product * np.argmax(rv, axis=-1)
cur_product *= factor
else:
rotated_vecs = np.concatenate([rotated_vecs, -rotated_vecs], axis=-1)
buckets = np.argmax(rotated_vecs, axis=-1)
# buckets is now (self.n_hashes, seqlen). Next we add offsets so that
# bucket numbers from different hashing rounds don't overlap.
offsets = jax.lax.tie_in(buckets, np.arange(self.n_hashes))
offsets = np.reshape(offsets * self.n_buckets, (-1, 1))
buckets = np.reshape(buckets + offsets, (-1,))
else:
assert not self._factorize_hash
rotated_vecs = np.concatenate([rotated_vecs, -rotated_vecs], axis=-1)
# In this configuration, we map each item to the top self.n_hashes buckets
rotated_vecs = np.squeeze(rotated_vecs, 0)
bucket_range = jax.lax.tie_in(vecs, np.arange(rotated_vecs.shape[-1]))
bucket_range = np.reshape(bucket_range, (1, -1))
bucket_range = np.broadcast_to(bucket_range, rotated_vecs.shape)
_, buckets = jax.lax.sort_key_val(
def PerformPositionOperations(pos, positions=None):
    """Gets pos and returns (q1, ..., q5).

    Builds five `QueryPositionKV` objects from the rows of `positions` and
    combines each of them with `pos` using the `@` operator.

    Args:
      pos: right-hand operand combined with every query via `@`.
      positions: table indexed as `positions[i, :]`. Despite the `None`
        default it must be supplied, because it is sliced immediately.

    Returns:
      A list of five `qt @ pos` results, in this order: default query,
      successor, subtract-one, binary add, binary subtract.
    """
    all_but_last = positions[:-1, :]
    all_but_first = positions[1:, :]
    half = int(positions.shape[0]) // 2

    def joined_rows(i, j):
        return np.concatenate([positions[i, :], positions[j, :]])

    # Addition iterates with i as the outer index; subtraction with j outer.
    add_pairs = [(i, j) for i in range(half) for j in range(half)]
    sub_pairs = [(i, j) for j in range(half) for i in range(half)]

    add_keys = np.array([joined_rows(i, j) for i, j in add_pairs])
    add_values = np.array([positions[i + j, :] for i, j in add_pairs])
    # TODO(lukaszkaiser): try this below: "for j in range(i) for i in range(2*l)"
    sub_keys = np.array([joined_rows(i, j) for i, j in sub_pairs])
    # Differences are clamped at row 0 rather than going negative.
    sub_values = np.array([positions[max(i - j, 0), :] for i, j in sub_pairs])

    query_types = [
        QueryPositionKV(),
        # Successor: row k looks up row k + 1.
        QueryPositionKV(keys=all_but_last, values=all_but_first),
        # Subtract one: row k + 1 looks up row k.
        QueryPositionKV(keys=all_but_first, values=all_but_last),
        QueryPositionKV(keys=add_keys, values=add_values, binary=True),
        QueryPositionKV(keys=sub_keys, values=sub_values, binary=True),
    ]
    return [qt @ pos for qt in query_types]  # pylint: disable=syntax-error
def PerformPositionOperations(pos, positions=None):
    """Gets pos and returns (q1, ..., q5).

    Builds five `QueryPositionKV` objects from the rows of `positions` and
    combines each of them with `pos` using the `@` operator.

    Args:
      pos: right-hand operand combined with every query via `@`.
      positions: table indexed as `positions[i, :]`. Despite the `None`
        default it must be supplied, because it is sliced immediately.

    Returns:
      A list of five `qt @ pos` results, in this order: default query,
      successor, subtract-one, binary add, binary subtract.
    """
    all_but_last = positions[:-1, :]
    all_but_first = positions[1:, :]
    half = int(positions.shape[0]) // 2

    def joined_rows(i, j):
        return np.concatenate([positions[i, :], positions[j, :]])

    # Addition iterates with i as the outer index; subtraction with j outer.
    add_pairs = [(i, j) for i in range(half) for j in range(half)]
    sub_pairs = [(i, j) for j in range(half) for i in range(half)]

    add_keys = np.array([joined_rows(i, j) for i, j in add_pairs])
    add_values = np.array([positions[i + j, :] for i, j in add_pairs])
    # TODO(lukaszkaiser): try this below: "for j in range(i) for i in range(2*l)"
    sub_keys = np.array([joined_rows(i, j) for i, j in sub_pairs])
    # Differences are clamped at row 0 rather than going negative.
    sub_values = np.array([positions[max(i - j, 0), :] for i, j in sub_pairs])

    query_types = [
        QueryPositionKV(),
        # Successor: row k looks up row k + 1.
        QueryPositionKV(keys=all_but_last, values=all_but_first),
        # Subtract one: row k + 1 looks up row k.
        QueryPositionKV(keys=all_but_first, values=all_but_last),
        QueryPositionKV(keys=add_keys, values=add_values, binary=True),
        QueryPositionKV(keys=sub_keys, values=sub_values, binary=True),
    ]
    return [qt @ pos for qt in query_types]  # pylint: disable=syntax-error