# Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# NOTE(review): fragment — the enclosing `def` and the opening `"""` of this
# docstring are not visible in this chunk (extraction appears scrambled), so
# the function name and signature cannot be confirmed from here.
:param sigmas: (..., ) matrix of raw sigma values
:param size: Tuple describing the tensor dimensions.
:param min_sigma: Minimal sigma value.
:return:(..., rank) sigma values
"""
# Remember the raw sigma shape; used to expand per-dimension copies below.
ssize = sigmas.size()
# Rank of the target tensor: each sigma is broadcast over all its dimensions.
r = len(size)
# Map raw values in (-inf, inf) into (min_sigma, inf). SIGMA_BOOST shifts the
# softplus input so raw values near zero still yield usably large sigmas.
# (Original comment said "Scale to [0, 1]", but softplus maps to (0, inf).)
sigmas = F.softplus(sigmas + SIGMA_BOOST) + min_sigma
# sigmas = sigmas[:, :, None].expand(b, k, r)
# Repeat each sigma once per rank dimension: (...,) -> (..., r).
sigmas = sigmas.unsqueeze(-1).expand(*(ssize + (r, )))
# Compute upper bounds: scale each per-dimension sigma by that dimension's
# size, expressing sigma in index units of the tensor.
s = torch.tensor(list(size), dtype=torch.float, device='cuda' if sigmas.is_cuda else 'cpu')
s = util.unsqueezen(s, len(sigmas.size()) - 1)
s = s.expand_as(sigmas)
return sigmas * s
def transform_means(means, size, method='sigmoid'):
"""
Transforms raw parameters for the index tuples (with values in (-inf, inf)) into parameters within the bound of the
dimensions of the tensor.
In the case of a templated sparse layer, these parameters and the corresponding size tuple describe only the learned
subtensor.
:param means: (..., rank) tensor of raw parameter values
:param size: Tuple describing the tensor dimensions.
:param method: One of 'modulo', 'clamp' or 'sigmoid' (default) — how raw values are mapped into bounds.
:return: (..., rank)
"""
# Compute upper bounds: the largest valid index per dimension (size - 1),
# broadcast to the shape of `means`.
s = torch.tensor(list(size), dtype=torch.float, device=d(means)) - 1
s = util.unsqueezen(s, len(means.size()) - 1)
s = s.expand_as(means)
# Scale to [0, 1]
if method == 'modulo':
# Wrap around: values beyond a bound re-enter from zero.
means = means.remainder(s)
return means
if method == 'clamp':
# Hard-clip into [0, s] per dimension.
means = torch.max(means, torch.zeros(means.size(), device=d(means)))
means = torch.min(means, s)
return means
# Default 'sigmoid' path: squash raw values into (0, 1).
means = torch.sigmoid(means)
# NOTE(review): this fragment ends here — the sigmoid branch appears truncated
# (presumably it finishes by scaling into bounds, e.g. `return means * s`;
# confirm against the full file).
# NOTE(review): fragment — interior of a sampling routine whose `def` is not
# visible in this chunk; `global_ints`, `bounds`, `pref`, `ladditional`,
# `rank`, `FT`, `epsilon`, `rng`, `relative_range` and `means` are defined
# elsewhere in the enclosing function.
assert (global_ints >= bounds).sum() == 0, 'One of the global sampled indices is outside the tensor bounds'
"""
Sample uniformly from a small range around the given index tuple
"""
# Allocate a (..., ladditional, rank) float buffer, fill with uniform samples
# in [0, 1), then shrink to [0, 1-epsilon) so later integer conversion cannot
# land exactly on an upper bound.
lsize = pref + (ladditional, rank)
local_ints = FT(*lsize)
local_ints.uniform_()
local_ints *= (1.0 - epsilon)
rngxp = util.unsqueezen(rng, len(lsize) - 1).expand_as(local_ints) # bounds of the tensor
rrng = FT(relative_range) # bounds of the range from which to sample
rrng = util.unsqueezen(rrng, len(lsize) - 1).expand_as(local_ints)
# print(means.size())
# Center each sampling window on the (rounded) continuous index tuple.
mns_expand = means.round().unsqueeze(-2).expand_as(local_ints)
# upper and lower bounds of each sampling window
lower = mns_expand - rrng * 0.5
upper = mns_expand + rrng * 0.5
# check for any ranges that are out of bounds
idxs = lower < 0.0
lower[idxs] = 0.0
idxs = upper > rngxp
# Shift overshooting windows back inside: lower edge becomes (bound - window
# size), so the whole window stays within the tensor.
lower[idxs] = rngxp[idxs] - rrng[idxs]
# Keep a copy of the raw uniform samples (used later in the routine).
cached = local_ints.clone()
# NOTE(review): fragment — the two lines below are the tail of a docstring
# whose opening is not visible in this chunk.
:param epsilon: The random numbers are based on uniform samples in (0, 1-epsilon). Note that
in some cases epsilon needs to be relatively big (e.g. 10^-5)
"""
# Unpack sizes from the continuous index-tuple tensor.
b = means.size(0)
k, c, rank = means.size()[-3:]
pref = means.size()[:-1]
# FT is the float-tensor constructor used for all sampling buffers; CUDA when
# requested.
FT = torch.cuda.FloatTensor if cuda else torch.FloatTensor
rng = FT(tuple(rng))
# - the tuple() is there in case a torch.Size() object is passed (which causes torch to
# interpret the argument as the size of the tensor rather than its content).
bounds = util.unsqueezen(rng, len(pref) + 1).long() # index bound with unsqueezed dims for broadcasting
# Optional determinism for reproducible sampling.
if seed is not None:
torch.manual_seed(seed)
"""
Generate neighbor tuples
"""
# fm is a (2^rank, rank) boolean floor-mask: each row selects which of the
# rank coordinates get floored (the rest are presumably ceiled — see the
# duplicated continuation later in this chunk).
if fm is None:
fm = floor_mask(rank, cuda)
size = pref + (2**rank, rank)
fm = util.unsqueezen(fm, len(size) - 2).expand(size)
# Expand each continuous tuple to its 2^rank integer-corner candidates.
neighbor_ints = means.data.unsqueeze(-2).expand(*size).contiguous()
neighbor_ints[fm] = neighbor_ints[fm].floor()
# NOTE(review): the lines from here down to the ceil step duplicate an earlier
# fragment in this chunk (likely an artifact of snippet extraction) before
# continuing with the ceil/long conversion and the start of global sampling.
# - the tuple() is there in case a torch.Size() object is passed (which causes torch to
# interpret the argument as the size of the tensor rather than its content).
bounds = util.unsqueezen(rng, len(pref) + 1).long() # index bound with unsqueezed dims for broadcasting
# Optional determinism for reproducible sampling.
if seed is not None:
torch.manual_seed(seed)
"""
Generate neighbor tuples
"""
if fm is None:
fm = floor_mask(rank, cuda)
size = pref + (2**rank, rank)
fm = util.unsqueezen(fm, len(size) - 2).expand(size)
neighbor_ints = means.data.unsqueeze(-2).expand(*size).contiguous()
# Floor the masked coordinates and ceil the rest: together the 2^rank masks
# enumerate every integer corner of the unit hypercube around each tuple.
neighbor_ints[fm] = neighbor_ints[fm].floor()
neighbor_ints[~fm] = neighbor_ints[~fm].ceil()
neighbor_ints = neighbor_ints.long()
assert (neighbor_ints >= bounds).sum() == 0, 'One of the neighbor indices is outside the tensor bounds'
"""
Sample uniformly from all integer tuples
"""
# Allocate a (..., gadditional, rank) float buffer for globally sampled index
# tuples; the filling of this buffer happens past the end of this chunk.
gsize = pref + (gadditional, rank)
global_ints = FT(*gsize)