How to use the torchkbnufft.math.conj_complex_mult function in torchkbnufft

To help you get started, we've selected a few torchkbnufft examples based on popular ways conj_complex_mult is used in public projects.

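A minimal, self-contained sketch may help before diving into the project code below. It assumes a pre-1.0 torchkbnufft, where complex tensors carry their real and imaginary parts stacked along a dedicated dimension and conj_complex_mult(a, b, dim) returns a multiplied by the complex conjugate of b:

import torch
from torchkbnufft.math import conj_complex_mult

# two "complex" tensors: dim 0 holds the real and imaginary parts
a = torch.randn(2, 5)
b = torch.randn(2, 5)

# c = a * conj(b); dim indicates the real/imaginary dimension
c = conj_complex_mult(a, b, dim=0)

# cross-check against native complex arithmetic
expected = torch.complex(a[0], a[1]) * torch.complex(b[0], b[1]).conj()
print(torch.allclose(c[0], expected.real))  # True
print(torch.allclose(c[1], expected.imag))  # True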

From github.com/mmuckley/torchkbnufft, torchkbnufft/nufft/fft_functions.py:

    else:
        x = x * torch.prod(grid_size)

    # scaling coefficient multiply
    while len(scaling_coef.shape) < len(x.shape):
        scaling_coef = scaling_coef.unsqueeze(0)

    # try to broadcast multiply - batch over coil if not enough memory
    raise_error = False
    try:
        x = conj_complex_mult(x, scaling_coef, dim=2)
    except RuntimeError as e:
        if 'out of memory' in str(e) and not raise_error:
            torch.cuda.empty_cache()
            for coilind in range(x.shape[1]):
                x[:, coilind, ...] = conj_complex_mult(
                    x[:, coilind:coilind + 1, ...], scaling_coef, dim=2)
            raise_error = True
        else:
            raise e
    except BaseException:
        # re-raise anything else unchanged
        raise

    return x
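
The broadcast multiply in the snippet above can be exercised on its own. This sketch uses hypothetical shapes and assumes the (nbatch, ncoil, 2) + grid_size layout that the snippet implies, with dim=2 as the real/imaginary dimension:

import torch
from torchkbnufft.math import conj_complex_mult

# hypothetical shapes: batch of 1, 8 coils, complex dim of size 2, 64x64 grid
x = torch.randn(1, 8, 2, 64, 64)
scaling_coef = torch.randn(2, 64, 64)

# prepend singleton dimensions until scaling_coef broadcasts against x,
# mirroring the while-loop in the snippet above
while len(scaling_coef.shape) < len(x.shape):
    scaling_coef = scaling_coef.unsqueeze(0)

# multiply x by the conjugate of the scaling coefficients
x = conj_complex_mult(x, scaling_coef, dim=2)
print(x.shape)  # torch.Size([1, 8, 2, 64, 64])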

From github.com/mmuckley/torchkbnufft, torchkbnufft/nufft/interp_functions.py:

    # indexing locations
    gridind = (kofflist + Jval.unsqueeze(1)).to(dtype)
    distind = torch.round(
        (tm - gridind) * L.unsqueeze(1)).to(dtype=int_type)
    gridind = gridind.to(int_type)

    arr_ind = torch.zeros((M,), dtype=int_type, device=device)
    coef = torch.stack((
        torch.ones(M, dtype=dtype, device=device),
        torch.zeros(M, dtype=dtype, device=device)
    ))

    for d in range(ndims):  # spatial dimension
        if conjcoef:
            coef = conj_complex_mult(
                coef,
                table[d][:, distind[d, :] + centers[d]],
                dim=0
            )
        else:
            coef = complex_mult(
                coef,
                table[d][:, distind[d, :] + centers[d]],
                dim=0
            )
        arr_ind = arr_ind + torch.remainder(gridind[d, :], dims[d]).view(-1) * \
            torch.prod(dims[d + 1:])

    return coef, arr_ind

From github.com/mmuckley/torchkbnufft, torchkbnufft/mri/sensenufft_functions.py:

    Returns:
        tensor: The images after forward and adjoint NUFFT of size
            (1, 2) + im_size.
    """
    # multiply sensitivities
    x = complex_mult(x, smap, dim=1)

    # Toeplitz NUFFT
    x = fft_filter(
        x.unsqueeze(0),
        kern.unsqueeze(0),
        norm=norm
    ).squeeze(0)

    # conjugate sum
    x = torch.sum(conj_complex_mult(x, smap, dim=1), dim=0, keepdim=True)

    return x
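
The conjugate sum above is the usual adjoint coil combination: each coil image is multiplied by the conjugate of its sensitivity map and the results are summed over coils. A small sketch of just that step, with hypothetical shapes and the same real/imaginary-stacked layout (it assumes complex_mult lives alongside conj_complex_mult in torchkbnufft.math):

import torch
from torchkbnufft.math import complex_mult, conj_complex_mult

# hypothetical shapes: 8 coils, complex dim of size 2, 64x64 image
image = torch.randn(1, 2, 64, 64)    # single "complex" image
smap = torch.randn(8, 2, 64, 64)     # coil sensitivity maps

# forward: modulate the image by each coil sensitivity
coil_images = complex_mult(image, smap, dim=1)

# adjoint: multiply by the conjugate sensitivities and sum over coils
combined = torch.sum(conj_complex_mult(coil_images, smap, dim=1),
                     dim=0, keepdim=True)
print(combined.shape)  # torch.Size([1, 2, 64, 64])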

From github.com/mmuckley/torchkbnufft, torchkbnufft/nufft/fft_functions.py:

    x = x[tuple(map(slice, crop_starts, crop_ends))]

    # scaling
    if norm == 'ortho':
        x = x * torch.sqrt(torch.prod(grid_size))
    else:
        x = x * torch.prod(grid_size)

    # scaling coefficient multiply
    while len(scaling_coef.shape) < len(x.shape):
        scaling_coef = scaling_coef.unsqueeze(0)

    # try to broadcast multiply - batch over coil if not enough memory
    raise_error = False
    try:
        x = conj_complex_mult(x, scaling_coef, dim=2)
    except RuntimeError as e:
        if 'out of memory' in str(e) and not raise_error:
            torch.cuda.empty_cache()
            for coilind in range(x.shape[1]):
                x[:, coilind, ...] = conj_complex_mult(
                    x[:, coilind:coilind + 1, ...], scaling_coef, dim=2)
            raise_error = True
        else:
            raise e
    except BaseException:
        # re-raise anything else unchanged
        raise

    return x

From github.com/mmuckley/torchkbnufft, torchkbnufft/mri/sensenufft_functions.py:

        interpob (dictionary): A NUFFT interpolation object.
        interp_mats (dictionary, default=None): A dictionary of sparse
            interpolation matrices. If not None, the NUFFT operation will use
            the matrices for interpolation.

    Returns:
        tensor: The images after adjoint NUFFT of size (nbatch, ncoil, 2) +
            im_size.
    """
    # adjoint nufft
    x = AdjKbNufftFunction.apply(y, om, interpob, interp_mats)

    # conjugate sum
    x = list(x)
    for i in range(len(x)):
        x[i] = torch.sum(conj_complex_mult(
            x[i], smap[i], dim=1), dim=0, keepdim=True)

    if isinstance(smap, torch.Tensor):
        x = torch.stack(x)

    return x