How to use the `torchkbnufft.math.complex_mult` function in torchkbnufft

To help you get started, we've selected a few torchkbnufft examples based on popular ways `complex_mult` is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes—no build needed—and fix issues immediately.

Source: GitHub, mmuckley/torchkbnufft — torchkbnufft/nufft/interp_functions.py (view on GitHub)
torch.floor(tm - numpoints.unsqueeze(1) / 2.0).to(dtype=torch.long)

    # initialize output array
    griddat = torch.zeros(size=(kdat.shape[0], 2, torch.prod(dims)),
                          dtype=dtype, device=device)

    # loop over offsets and take advantage of numpy broadcasting
    for Jind in range(Jlist.shape[1]):
        coef, arr_ind = calc_coef_and_indices(
            tm, kofflist, Jlist[:, Jind], table, centers, L, dims, conjcoef=True)

        # the following code takes ordered data and scatters it on to an image grid
        # profiling for a 2D problem showed drastic differences in performances
        # for these two implementations on cpu/gpu, but they do the same thing
        if device == torch.device('cpu'):
            tmp = complex_mult(coef.unsqueeze(0), kdat, dim=1)
            for bind in range(griddat.shape[0]):
                for riind in range(griddat.shape[1]):
                    griddat[bind, riind].index_put_(
                        tuple(arr_ind.unsqueeze(0)),
                        tmp[bind, riind],
                        accumulate=True
                    )
        else:
            griddat.index_add_(
                2,
                arr_ind,
                complex_mult(coef.unsqueeze(0), kdat, dim=1)
            )

    return griddat
Source: GitHub, mmuckley/torchkbnufft — torchkbnufft/nufft/interp_functions.py (view on GitHub)
arr_ind = torch.zeros((M,), dtype=int_type, device=device)
    coef = torch.stack((
        torch.ones(M, dtype=dtype, device=device),
        torch.zeros(M, dtype=dtype, device=device)
    ))

    for d in range(ndims):  # spatial dimension
        if conjcoef:
            coef = conj_complex_mult(
                coef,
                table[d][:, distind[d, :] + centers[d]],
                dim=0
            )
        else:
            coef = complex_mult(
                coef,
                table[d][:, distind[d, :] + centers[d]],
                dim=0
            )
        arr_ind = arr_ind + torch.remainder(gridind[d, :], dims[d]).view(-1) * \
            torch.prod(dims[d + 1:])

    return coef, arr_ind
Source: GitHub, mmuckley/torchkbnufft — torchkbnufft/mri/sensenufft_functions.py (view on GitHub)
Args:
        x (tensor): The input images of size (1, 2) + im_size.
        smap (tensor): The sensitivity maps of size (ncoil, 2) +
            im_size.
        kern (tensor): Embedded Toeplitz NUFFT kernel of size
            (ncoil, 2) + im_size*2.
        norm (str, default=None): If 'ortho', use orthogonal FFTs for Toeplitz
            NUFFT filter.

    Returns:
        tensor: The images after forward and adjoint NUFFT of size
            (1, 2) + im_size.
    """
    # multiply sensitivities
    x = complex_mult(x, smap, dim=1)

    # Toeplitz NUFFT
    x = fft_filter(
        x.unsqueeze(0),
        kern.unsqueeze(0),
        norm=norm
    ).squeeze(0)

    # conjugate sum
    x = torch.sum(conj_complex_mult(x, smap, dim=1), dim=0, keepdim=True)

    return x
Source: GitHub, mmuckley/torchkbnufft — torchkbnufft/mri/sensenufft_functions.py (view on GitHub)
Returns:
        tensor: Output off-grid k-space data of dimensions (nbatch, ncoil, 2,
            klength).
    """
    if isinstance(smap, torch.Tensor):
        dtype = smap.dtype
        device = smap.device
        mult_x = torch.zeros(smap.shape, dtype=dtype, device=device)
    else:
        mult_x = [None] * len(smap)

    # handle batch dimension
    for i, im in enumerate(x):
        # multiply sensitivities
        mult_x[i] = complex_mult(im, smap[i], dim=1)

    y = KbNufftFunction.apply(mult_x, om, interpob, interp_mats)

    return y
Source: GitHub, mmuckley/torchkbnufft — torchkbnufft/nufft/interp_functions.py (view on GitHub)
if interp_mats is None:
            params['dims'] = torch.tensor(
                x[b].shape[2:], dtype=torch.long, device=device)

            y.append(run_interp(x[b].view((x.shape[1], 2, -1)), tm[b], params))
        else:
            y.append(
                run_mat_interp(
                    x[b].view((x.shape[1], 2, -1)),
                    interp_mats['real_interp_mats'][b],
                    interp_mats['imag_interp_mats'][b],
                )
            )

        # phase for fftshift
        y[-1] = complex_mult(
            y[-1],
            imag_exp(torch.mv(torch.transpose(
                om[b], 1, 0), n_shift)).unsqueeze(0),
            dim=1
        )

    y = torch.stack(y)

    return y
Source: GitHub, mmuckley/torchkbnufft — torchkbnufft/nufft/interp_functions.py (view on GitHub)
# profiling for a 2D problem showed drastic differences in performances
        # for these two implementations on cpu/gpu, but they do the same thing
        if device == torch.device('cpu'):
            tmp = complex_mult(coef.unsqueeze(0), kdat, dim=1)
            for bind in range(griddat.shape[0]):
                for riind in range(griddat.shape[1]):
                    griddat[bind, riind].index_put_(
                        tuple(arr_ind.unsqueeze(0)),
                        tmp[bind, riind],
                        accumulate=True
                    )
        else:
            griddat.index_add_(
                2,
                arr_ind,
                complex_mult(coef.unsqueeze(0), kdat, dim=1)
            )

    return griddat
Source: GitHub, mmuckley/torchkbnufft — torchkbnufft/nufft/fft_functions.py (view on GitHub)
inv_permute_dims.append(2 + i)
    permute_dims.append(2)
    pad_sizes = tuple(pad_sizes)
    permute_dims = tuple(permute_dims)
    inv_permute_dims = tuple(inv_permute_dims)

    # zero pad and fft
    x = F.pad(x, pad_sizes)
    x = x.permute(permute_dims)
    x = torch.fft(x, grid_size.numel())
    if norm == 'ortho':
        x = x / torch.sqrt(torch.prod(grid_size.to(torch.double)))
    x = x.permute(inv_permute_dims)

    # apply the filter
    x = complex_mult(x, kern, dim=2)

    # inverse fft
    x = x.permute(permute_dims)
    x = torch.ifft(x, grid_size.numel())
    x = x.permute(inv_permute_dims)

    # crop to input size
    crop_starts = tuple(np.array(x.shape).astype(np.int) * 0)
    crop_ends = [x.shape[0], x.shape[1], x.shape[2]]
    for dim in im_size:
        crop_ends.append(int(dim))
    x = x[tuple(map(slice, crop_starts, crop_ends))]

    # scaling, assume user handled adjoint scaling with their kernel
    if norm == 'ortho':
        x = x / torch.sqrt(torch.prod(grid_size.to(torch.double)))
Source: GitHub, mmuckley/torchkbnufft — torchkbnufft/nufft/interp_functions.py (view on GitHub)
dtype=dtype, device=device)

    # loop over offsets and take advantage of broadcasting
    for Jind in range(Jlist.shape[1]):
        coef, arr_ind = calc_coef_and_indices(
            tm, kofflist, Jlist[:, Jind], table, centers, L, dims)

        # unsqueeze coil and real/imag dimensions for on-grid indices
        arr_ind = arr_ind.unsqueeze(0).unsqueeze(0).expand(
            kdat.shape[0],
            kdat.shape[1],
            -1
        )

        # gather and multiply coefficients
        kdat += complex_mult(
            coef.unsqueeze(0),
            torch.gather(griddat, 2, arr_ind),
            dim=1
        )

    return kdat