Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
torch.floor(tm - numpoints.unsqueeze(1) / 2.0).to(dtype=torch.long)
# initialize output array
griddat = torch.zeros(size=(kdat.shape[0], 2, torch.prod(dims)),
dtype=dtype, device=device)
# loop over offsets and take advantage of numpy broadcasting
for Jind in range(Jlist.shape[1]):
coef, arr_ind = calc_coef_and_indices(
tm, kofflist, Jlist[:, Jind], table, centers, L, dims, conjcoef=True)
# the following code takes ordered data and scatters it on to an image grid
# profiling for a 2D problem showed drastic differences in performances
# for these two implementations on cpu/gpu, but they do the same thing
if device == torch.device('cpu'):
tmp = complex_mult(coef.unsqueeze(0), kdat, dim=1)
for bind in range(griddat.shape[0]):
for riind in range(griddat.shape[1]):
griddat[bind, riind].index_put_(
tuple(arr_ind.unsqueeze(0)),
tmp[bind, riind],
accumulate=True
)
else:
griddat.index_add_(
2,
arr_ind,
complex_mult(coef.unsqueeze(0), kdat, dim=1)
)
return griddat
arr_ind = torch.zeros((M,), dtype=int_type, device=device)
coef = torch.stack((
torch.ones(M, dtype=dtype, device=device),
torch.zeros(M, dtype=dtype, device=device)
))
for d in range(ndims): # spatial dimension
if conjcoef:
coef = conj_complex_mult(
coef,
table[d][:, distind[d, :] + centers[d]],
dim=0
)
else:
coef = complex_mult(
coef,
table[d][:, distind[d, :] + centers[d]],
dim=0
)
arr_ind = arr_ind + torch.remainder(gridind[d, :], dims[d]).view(-1) * \
torch.prod(dims[d + 1:])
return coef, arr_ind
Args:
x (tensor): The input images of size (1, 2) + im_size.
smap (tensor): The sensitivity maps of size (ncoil, 2) +
im_size.
kern (tensor): Embedded Toeplitz NUFFT kernel of size
(ncoil, 2) + im_size*2.
norm (str, default=None): If 'ortho', use orthogonal FFTs for Toeplitz
NUFFT filter.
Returns:
tensor: The images after forward and adjoint NUFFT of size
(1, 2) + im_size.
"""
# multiply sensitivities
x = complex_mult(x, smap, dim=1)
# Toeplitz NUFFT
x = fft_filter(
x.unsqueeze(0),
kern.unsqueeze(0),
norm=norm
).squeeze(0)
# conjugate sum
x = torch.sum(conj_complex_mult(x, smap, dim=1), dim=0, keepdim=True)
return x
Returns:
tensor: Output off-grid k-space data of dimensions (nbatch, ncoil, 2,
klength).
"""
if isinstance(smap, torch.Tensor):
dtype = smap.dtype
device = smap.device
mult_x = torch.zeros(smap.shape, dtype=dtype, device=device)
else:
mult_x = [None] * len(smap)
# handle batch dimension
for i, im in enumerate(x):
# multiply sensitivities
mult_x[i] = complex_mult(im, smap[i], dim=1)
y = KbNufftFunction.apply(mult_x, om, interpob, interp_mats)
return y
if interp_mats is None:
params['dims'] = torch.tensor(
x[b].shape[2:], dtype=torch.long, device=device)
y.append(run_interp(x[b].view((x.shape[1], 2, -1)), tm[b], params))
else:
y.append(
run_mat_interp(
x[b].view((x.shape[1], 2, -1)),
interp_mats['real_interp_mats'][b],
interp_mats['imag_interp_mats'][b],
)
)
# phase for fftshift
y[-1] = complex_mult(
y[-1],
imag_exp(torch.mv(torch.transpose(
om[b], 1, 0), n_shift)).unsqueeze(0),
dim=1
)
y = torch.stack(y)
return y
# profiling for a 2D problem showed drastic differences in performances
# for these two implementations on cpu/gpu, but they do the same thing
if device == torch.device('cpu'):
tmp = complex_mult(coef.unsqueeze(0), kdat, dim=1)
for bind in range(griddat.shape[0]):
for riind in range(griddat.shape[1]):
griddat[bind, riind].index_put_(
tuple(arr_ind.unsqueeze(0)),
tmp[bind, riind],
accumulate=True
)
else:
griddat.index_add_(
2,
arr_ind,
complex_mult(coef.unsqueeze(0), kdat, dim=1)
)
return griddat
inv_permute_dims.append(2 + i)
permute_dims.append(2)
pad_sizes = tuple(pad_sizes)
permute_dims = tuple(permute_dims)
inv_permute_dims = tuple(inv_permute_dims)
# zero pad and fft
x = F.pad(x, pad_sizes)
x = x.permute(permute_dims)
x = torch.fft(x, grid_size.numel())
if norm == 'ortho':
x = x / torch.sqrt(torch.prod(grid_size.to(torch.double)))
x = x.permute(inv_permute_dims)
# apply the filter
x = complex_mult(x, kern, dim=2)
# inverse fft
x = x.permute(permute_dims)
x = torch.ifft(x, grid_size.numel())
x = x.permute(inv_permute_dims)
# crop to input size
crop_starts = tuple(np.array(x.shape).astype(np.int) * 0)
crop_ends = [x.shape[0], x.shape[1], x.shape[2]]
for dim in im_size:
crop_ends.append(int(dim))
x = x[tuple(map(slice, crop_starts, crop_ends))]
# scaling, assume user handled adjoint scaling with their kernel
if norm == 'ortho':
x = x / torch.sqrt(torch.prod(grid_size.to(torch.double)))
dtype=dtype, device=device)
# loop over offsets and take advantage of broadcasting
for Jind in range(Jlist.shape[1]):
coef, arr_ind = calc_coef_and_indices(
tm, kofflist, Jlist[:, Jind], table, centers, L, dims)
# unsqueeze coil and real/imag dimensions for on-grid indices
arr_ind = arr_ind.unsqueeze(0).unsqueeze(0).expand(
kdat.shape[0],
kdat.shape[1],
-1
)
# gather and multiply coefficients
kdat += complex_mult(
coef.unsqueeze(0),
torch.gather(griddat, 2, arr_ind),
dim=1
)
return kdat