dot = torch.bmm(squeries[:, None, :], skeys[:, :, None]).view(b*h, tp*vs)
dot_logits = dot.data.clone()
assert not util.contains_inf(dot), f'dot contains inf (before norm) {dot.min()}, {dot.mean()}, {dot.max()}'
assert not util.contains_nan(dot), f'dot contains nan (before norm) {dot.min()}, {dot.mean()}, {dot.max()}'
if self.norm_method == 'softmax':
    dot = sparse.logsoftmax(indices, weights * dot, size).exp()
else:
    dot = sparse.simple_normalize(indices, weights * dot, size, method=self.norm_method)
# - dot now has row-wise self-attention probabilities
assert not util.contains_inf(dot), f'dot contains inf (after norm) {dot.min()}, {dot.mean()}, {dot.max()}'
try:
    assert not util.contains_nan(dot), f'dot contains nan (after norm) {dot.min()}, {dot.mean()}, {dot.max()}'
except AssertionError:
    print(dot.sum(dim=1))
    print('\n\n\n')

    for i in range(b*h):
        print(f'*** {i}')
        print(indices[i])
        print(dot_logits[i])
        print((weights * dot_logits)[i])

    print('\n\n\n')
    sys.exit()
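# The asserts above rely on util.contains_nan / util.contains_inf, which are not
# shown in this snippet. A minimal sketch of what they are assumed to do:

import torch

def contains_nan(tensor):
    # NaN is the only value for which x != x; torch.isnan checks this elementwise.
    return bool(torch.isnan(tensor).any())

def contains_inf(tensor):
    # True if any element is +inf or -inf.
    return bool(torch.isinf(tensor).any())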
if self.norm_method == 'softmax':
    dot = sparse.logsoftmax(indices, dot, size).exp()
else:
    dot = sparse.simple_normalize(indices, dot, size, method=self.norm_method)
# - dot now has row-wise self-attention probabilities
# assert not util.contains_inf(dot), f'dot contains inf (after softmax) {dot.min()}, {dot.mean()}, {dot.max()}'
# assert not util.contains_nan(dot), f'dot contains nan (after softmax) {dot.min()}, {dot.mean()}, {dot.max()}'
# apply the self attention to the values
out = sparse.batchmm(indices, dot, size=size, xmatrix=values)
# swap h, t back, unify heads
out = out.transpose(1, 2).contiguous().view(b, t, h * s)
out = self.unifyheads(out)
assert not util.contains_nan(out), f'output contains nan {out}, dot min/max: {dot.min()}/{dot.max()}'
return out
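# For intuition: the pipeline above (a log-softmax over the selected sparse
# index pairs, followed by sparse.batchmm against the values) computes ordinary
# self-attention restricted to those pairs. A dense reference sketch under that
# assumption (all names here are hypothetical):

import torch
import torch.nn.functional as F

def dense_reference(dot_logits, mask, values):
    # dot_logits: (b, t, t) raw attention scores
    # mask:       (b, t, t) bool, True where an index pair was selected
    #             (assumes every row has at least one selected pair)
    # values:     (b, t, e) value vectors
    dot_logits = dot_logits.masked_fill(~mask, float('-inf'))
    probs = F.softmax(dot_logits, dim=2)  # row-wise attention probabilities
    return torch.bmm(probs, values)       # weighted sum of values per position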
keys = keys / (e ** (1/4))
# get dot product of queries and keys
# - this will be a sparse matrix with the indices we've just computed, and values
# defined by the dot product
# select the queries
indflat = indices.view(b*h*t*vs, 2)
ar = torch.arange(b*h, dtype=torch.long, device=d(x))[:, None].expand(b*h, t*vs).contiguous().view(b*h*t*vs)
squeries = queries[ar, indflat[:, 0], :]
skeys = keys[ar, indflat[:, 1], :]
dot = torch.bmm(squeries[:, None, :], skeys[:, :, None]).view(b*h, t*vs)
#print(f'dot before {dot.min()}, {dot.mean()}, {dot.max()}')
assert not util.contains_nan(dot), f'dot contains nan (before softmax) {dot.min()}, {dot.mean()}, {dot.max()}'
#print(f'dot after {dot.min()}, {dot.mean()}, {dot.max()}\n')
dot = sparse.logsoftmax(indices, weights * dot, s).exp()
# - dot now has row-wise self-attention probabilities
assert not util.contains_nan(dot), f'dot contains nan (after softmax) {dot.min()}, {dot.mean()}, {dot.max()}'
# apply the self attention to the values
out = sparse.batchmm(indices, dot, size=(t, t), xmatrix=values)
# swap h, t back, unify heads
out = out.transpose(1, 2).contiguous().view(b, t, h * e)
out = self.unifyheads(out)
assert not util.contains_nan(out), f'output contains nan {out}'
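# sparse.batchmm itself is not shown. A sketch of the operation it is assumed
# to perform, written with torch's built-in COO sparse tensors (one sparse
# (t, t) attention matrix per batch element, multiplied into the dense values):

import torch

def batchmm_sketch(indices, weights, size, xmatrix):
    # indices: (b, k, 2) long (row, col) pairs; weights: (b, k) per-pair values;
    # size: (t, t); xmatrix: (b, t, e) -> returns (b, t, e)
    out = []
    for i in range(indices.size(0)):
        sp = torch.sparse_coo_tensor(indices[i].t(), weights[i], size)
        out.append(torch.sparse.mm(sp, xmatrix[i]))
    return torch.stack(out, dim=0)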
else:  # not train, just use the nearest indices
    indices = means.round().long()

if self.use_cuda:
    indices = indices.cuda()
# translate tensor indices to matrix indices so we can use matrix multiplication to perform the tensor contraction
mindices, flat_size = flatten_indices_mat(indices, input.size()[1:], self.out_size)
### Create the sparse weight tensor
x_flat = input.view(batchsize, -1)
# Prevent segfault
assert not contains_nan(values.data)
bm = self.bmult(flat_size[1], flat_size[0], mindices.size()[1], batchsize, self.use_cuda)
bfsize = Variable(flat_size * batchsize)
bfindices = mindices + bm
bfindices = bfindices.view(1, -1, 2).squeeze(0)
vindices = Variable(bfindices.t())
bfvalues = values.view(1, -1).squeeze(0)
bfx = x_flat.view(1, -1).squeeze(0)
spm = sparsemult(self.use_cuda)
bfy = spm(vindices, bfvalues, bfsize, bfx)
y_flat = bfy.unsqueeze(0).view(batchsize, -1)
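# The bm offset from self.bmult turns per-instance matrix indices into indices
# into one large block-diagonal matrix, so the whole batch reduces to a single
# sparse multiply against the flattened input. A sketch of that offsetting idea
# (hypothetical helper, not the repo's bmult):

import torch

def batch_offsets(rows, cols, nnz, batchsize):
    # Adding this (batchsize * nnz, 2) tensor to tiled per-instance indices
    # places instance i in diagonal block (i * rows, i * cols).
    ar = torch.arange(batchsize, dtype=torch.long)
    offset = torch.stack([ar * rows, ar * cols], dim=1)  # (batchsize, 2)
    return offset[:, None, :].expand(batchsize, nnz, 2).reshape(-1, 2)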
assert indices.size() == (b, 1, vs, len(s))
assert weights.size() == (b, 1, vs)
indices, weights = indices.squeeze(1), weights.squeeze(1)
else:
    vs = 1
    indices = means.floor().to(torch.long).detach().squeeze(1)
# Select a single code from the latent space (per instance in batch).
# When sampling, this is a weighted sum; when not sampling, it is a single code.
indices = indices.view(b*vs, len(s))
# checks to prevent segfaults
if util.contains_nan(indices):
    print(params)
    raise Exception('Indices contain NaN')

if indices[:, 0].max() >= s[0] or indices[:, 1].max() >= s[1]:
    print(indices.max())
    print(params)
    raise Exception('Indices out of bounds')
if len(s) == 1:
    code = self.latent[indices[:, 0], :]
elif len(s) == 2:
    code = self.latent[indices[:, 0], indices[:, 1], :]
elif len(s) == 3:
    code = self.latent[indices[:, 0], indices[:, 1], indices[:, 2], :]
def hyper(self, x):
    b, t, e = x.size()
    h, k, reg = self.heads, self.k, self.region
    o = t if self.outputs < 1 else self.outputs  # any value below 1 means: one output per input position

    # Generate coords
    coords = torch.arange(t, dtype=torch.float, device=d(x)) / t
    coords = coords[None, :, None].expand(b, t, 1)

    input = torch.cat([x, coords], dim=2)
    params = self.toparams(input)  # (b, o, k*2)

    assert not util.contains_nan(params), \
        f'params contain NaN\n input {input.min()} {input.max()} \n {list(self.toparams.parameters())}'

    # Generate the logits that correspond to the horizontal coordinate of the current word
    diags = torch.arange(t, dtype=torch.float, device=d(x))
    if not self.clamp:
        diags = util.inv(diags, mx=t)
    diags = diags[None, :, None, None].expand(b, t, k, 1)

    means = params[:, :, :k].view(b, t, k, 1)
    sigmas = params[:, :, k:].view(b, t, k)
    values = self.mvalues[None, None, :].expand(b, t, k)

    means = diags - self.mmult * F.softplus(means)

    s = (t,)
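# The means above are shifted to lie at or below the current diagonal position:
# softplus is strictly positive, so diags - mmult * softplus(means) can never
# reach the diagonal itself. A tiny numeric check (mmult assumed to be 1.0):

import torch
import torch.nn.functional as F

diag = torch.tensor([5.0])
raw = torch.tensor([-3.0, 0.0, 3.0])
shifted = diag - 1.0 * F.softplus(raw)
assert (shifted < diag).all()  # every mean stays strictly below position 5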
# translate tensor indices to matrix indices
# mindices, flat_size = flatten_indices(indices, input.size()[1:], self.out_shape, self.use_cuda)
mindices, flat_size = flatten_indices_mat(indices, input.size()[1:], self.out_size)
# NB: mindices is not an autograd Variable. The error-signal for the indices passes to the hypernetwork
# through 'values', which are a function of both the real_indices and the real_values.
### Create the sparse weight tensor
x_flat = input.view(batchsize, -1)
# Prevent segfault
try:
    assert mindices.min() >= 0
    assert not contains_nan(values.data)
except AssertionError as ae:
    print('Nan in values or negative index in mindices.')
    print('means', means)
    print('sigmas', sigmas)
    print('props', props)
    print('values', values)
    print('indices', indices)
    print('mindices', mindices)
    raise ae
# Then we flatten the batch dimension as well
bm = bmult(flat_size[1], flat_size[0], mindices.size()[1], batchsize, self.use_cuda)
bfsize = Variable(flat_size * batchsize)
bfindices = mindices + bm
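# flatten_indices_mat (not shown) is assumed to translate k-dimensional tensor
# indices into (row, col) indices of an equivalent matrix, so that the tensor
# contraction becomes a matrix multiply. The core of that translation is
# row-major ravelling with strides; a sketch for one side:

import torch

def ravel(indices, shape):
    # indices: (n, k) long tensor; shape: tuple of k ints -> (n,) flat indices
    strides, acc = [], 1
    for dim in reversed(shape):
        strides.insert(0, acc)
        acc *= dim
    strides = torch.tensor(strides, dtype=torch.long)
    return (indices * strides[None, :]).sum(dim=1)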
source, target = source.cuda(), target.cuda()
source, target = Variable(source), Variable(target)
output = model(source)
loss = F.nll_loss(output.transpose(2, 1), target, reduction='none')
loss = loss.mean()
tbw.add_scalar('transformer/train-loss', float(loss.item()) * LOG2E, i * arg.batch_size)
assert loss.item() == loss.item(), f'Loss is nan {loss}'
loss.backward()
assert not util.contains_nan(model.parameters()), f'Parameters have become NaN {model.parameters()}'
if arg.cuda and (i == 0 or random.random() < 0.0005):  # occasionally print peak GPU memory usage
    print(f'\nPeak gpu memory use is {torch.cuda.max_memory_cached() / 1e9:.2} Gb')
# clip gradients
if arg.gradient_clipping is not None:
    nn.utils.clip_grad_norm_(model.parameters(), arg.gradient_clipping)
opt.step()
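# LOG2E in the logging call above converts the NLL loss from nats to bits, the
# usual unit for character-level results. Assuming the standard definition:

import math
LOG2E = math.log2(math.e)  # multiplying a value in nats by log2(e) gives bits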
if (arg.model.startswith('sparse') or arg.model == 'strided' or arg.model == 'mixed') and arg.plot_every > 0 and i % arg.plot_every == 0:
    shape = (arg.context, arg.context)
    means, sigmas, values = model.forward_for_plot(source)

    for t, (m, s, v) in enumerate(zip(means, sigmas, values)):
        b, c, k, r = m.size()
# weight the values by the proportions
weights = mvalues[:, :, None, :].expand_as(props)
# - add a dim for the MVNs
weights = props * weights
weights = weights.sum(dim=3) # - sum out the MVNs
out = selection[None, :, None, None].expand(b, tp, vs, 1) # output indices
indices = torch.cat([out, indices], dim=3)
assert indices.size() == (b, tp, vs, 2), f'{indices.size()}, {(b, tp, vs, 2)}'
assert weights.size() == (b, tp, vs), f'{weights.size()}, {(b, tp, vs)}'
assert not util.contains_inf(weights), f'weights contains inf (before norm) {weights.min()}, {weights.mean()}, {weights.max()}'
assert not util.contains_nan(weights), f'weights contains nan (before norm) {weights.min()}, {weights.mean()}, {weights.max()}'
# expand for heads, fold heads into batch
indices = indices[:, None, :, :, :].expand(b, h, tp, vs, 2).contiguous().view(b*h, tp*vs, 2)
weights = weights[:, None, :, :].expand(b, h, tp, vs).contiguous().view(b*h, tp*vs)
# compute keys, queries, values
keys = self.tokeys(x) # note: t not tp, we compute _all_ queries, keys and values
queries = self.toqueries(x)
values = self.tovalues(x)
# - fold heads into the batch dimension
keys = keys.transpose(1, 2).contiguous().view(b * h, t, s)
queries = queries.transpose(1, 2).contiguous().view(b * h, t, s)
values = values.transpose(1, 2).contiguous().view(b * h, t, s)
# -- We could actually select first, and _then_ transform to kqv's. May be better for very large contexts and
# small batches
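# A sketch of the select-first idea from the comment above: gather the inputs
# at the selected positions before projecting, so the key/query transforms run
# only over the selected pairs rather than the full context (hypothetical
# shapes, and the linear modules are assumed):

import torch

def select_then_project(x, pairs, tokeys, toqueries):
    # x: (b, t, e) input sequence; pairs: (b, n, 2) long (query_pos, key_pos)
    b, n, _ = pairs.size()
    ar = torch.arange(b)[:, None].expand(b, n)
    q_in = x[ar, pairs[:, :, 0], :]  # (b, n, e) inputs at the query positions
    k_in = x[ar, pairs[:, :, 1], :]  # (b, n, e) inputs at the key positions
    return toqueries(q_in), tokeys(k_in)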