How to use the torchkbnufft.MriSenseNufft class in torchkbnufft

To help you get started, we’ve selected a few torchkbnufft examples based on popular ways MriSenseNufft is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

GitHub: mmuckley/torchkbnufft — tests/test_pytorch_sparse_grad_adj_matching.py (View on GitHub)
norm_tol = testing_tol

    im_size = params_2d['im_size']
    numpoints = params_2d['numpoints']

    x = params_2d['x']
    y = params_2d['y']
    ktraj = params_2d['ktraj']
    smap = params_2d['smap']

    for device in device_list:
        x = x.detach().to(dtype=dtype, device=device)
        y = y.detach().to(dtype=dtype, device=device)
        ktraj = ktraj.detach().to(dtype=dtype, device=device)

        sensenufft_ob = MriSenseNufft(
            smap=smap,
            im_size=im_size,
            numpoints=numpoints
        ).to(dtype=dtype, device=device)
        adjsensenufft_ob = AdjMriSenseNufft(
            smap=smap,
            im_size=im_size,
            numpoints=numpoints
        ).to(dtype=dtype, device=device)

        real_mat, imag_mat = precomp_sparse_mats(ktraj, sensenufft_ob)
        interp_mats = {
            'real_interp_mats': real_mat,
            'imag_interp_mats': imag_mat
        }
GitHub: mmuckley/torchkbnufft — tests/test_kb_construction.py (View on GitHub)
ob = KbInterpForw(
        im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, numpoints=numpoints,
        table_oversamp=table_oversamp, kbwidth=kbwidth, order=order)
    ob = KbInterpBack(
        im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, numpoints=numpoints,
        table_oversamp=table_oversamp, kbwidth=kbwidth, order=order)

    ob = KbNufft(
        im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, numpoints=numpoints,
        table_oversamp=table_oversamp, kbwidth=kbwidth, order=order, norm=norm)
    ob = AdjKbNufft(
        im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, numpoints=numpoints,
        table_oversamp=table_oversamp, kbwidth=kbwidth, order=order, norm=norm)

    ob = MriSenseNufft(
        smap=smap, im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, numpoints=numpoints,
        table_oversamp=table_oversamp, kbwidth=kbwidth, order=order, norm=norm)
    ob = AdjMriSenseNufft(
        smap=smap, im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, numpoints=numpoints,
        table_oversamp=table_oversamp, kbwidth=kbwidth, order=order, norm=norm)
GitHub: mmuckley/torchkbnufft — tests/test_pytorch_grad_adj_matching.py (View on GitHub)
norm_tol = testing_tol

    im_size = params_3d['im_size']
    numpoints = params_3d['numpoints']

    x = params_3d['x']
    y = params_3d['y']
    ktraj = params_3d['ktraj']
    smap = params_3d['smap']

    for device in device_list:
        x = x.detach().to(dtype=dtype, device=device)
        y = y.detach().to(dtype=dtype, device=device)
        ktraj = ktraj.detach().to(dtype=dtype, device=device)

        sensenufft_ob = MriSenseNufft(
            smap=smap,
            im_size=im_size,
            numpoints=numpoints
        ).to(dtype=dtype, device=device)
        adjsensenufft_ob = AdjMriSenseNufft(
            smap=smap,
            im_size=im_size,
            numpoints=numpoints
        ).to(dtype=dtype, device=device)

        x.requires_grad = True
        y = sensenufft_ob.forward(x, ktraj)

        ((y ** 2) / 2).sum().backward()
        x_grad = x.grad.clone().detach()
GitHub: mmuckley/torchkbnufft — tests/test_sparse_adjoints.py (View on GitHub)
norm_tol = testing_tol

    im_size = params_2d['im_size']
    numpoints = params_2d['numpoints']

    x = params_2d['x']
    y = params_2d['y']
    ktraj = params_2d['ktraj']
    smap = params_2d['smap']

    for device in device_list:
        x = x.detach().to(dtype=dtype, device=device)
        y = y.detach().to(dtype=dtype, device=device)
        ktraj = ktraj.detach().to(dtype=dtype, device=device)

        sensenufft_ob = MriSenseNufft(
            smap=smap,
            im_size=im_size,
            numpoints=numpoints,
            coilpack=True
        ).to(dtype=dtype, device=device)
        adjsensenufft_ob = AdjMriSenseNufft(
            smap=smap,
            im_size=im_size,
            numpoints=numpoints,
            coilpack=True
        ).to(dtype=dtype, device=device)

        real_mat, imag_mat = precomp_sparse_mats(ktraj, sensenufft_ob)
        interp_mats = {
            'real_interp_mats': real_mat,
            'imag_interp_mats': imag_mat
GitHub: mmuckley/torchkbnufft — tests/test_pytorch_sparse_grad_adj_matching.py (View on GitHub)
norm_tol = testing_tol

    im_size = params_3d['im_size']
    numpoints = params_3d['numpoints']

    x = params_3d['x']
    y = params_3d['y']
    ktraj = params_3d['ktraj']
    smap = params_3d['smap']

    for device in device_list:
        x = x.detach().to(dtype=dtype, device=device)
        y = y.detach().to(dtype=dtype, device=device)
        ktraj = ktraj.detach().to(dtype=dtype, device=device)

        sensenufft_ob = MriSenseNufft(
            smap=smap,
            im_size=im_size,
            numpoints=numpoints
        ).to(dtype=dtype, device=device)
        adjsensenufft_ob = AdjMriSenseNufft(
            smap=smap,
            im_size=im_size,
            numpoints=numpoints
        ).to(dtype=dtype, device=device)

        real_mat, imag_mat = precomp_sparse_mats(ktraj, sensenufft_ob)
        interp_mats = {
            'real_interp_mats': real_mat,
            'imag_interp_mats': imag_mat
        }
GitHub: mmuckley/torchkbnufft — tests/test_pytorch_grad_adj_matching.py (View on GitHub)
norm_tol = testing_tol

    im_size = params_2d['im_size']
    numpoints = params_2d['numpoints']

    x = params_2d['x']
    y = params_2d['y']
    ktraj = params_2d['ktraj']
    smap = params_2d['smap']

    for device in device_list:
        x = x.detach().to(dtype=dtype, device=device)
        y = y.detach().to(dtype=dtype, device=device)
        ktraj = ktraj.detach().to(dtype=dtype, device=device)

        sensenufft_ob = MriSenseNufft(
            smap=smap,
            im_size=im_size,
            numpoints=numpoints
        ).to(dtype=dtype, device=device)
        adjsensenufft_ob = AdjMriSenseNufft(
            smap=smap,
            im_size=im_size,
            numpoints=numpoints
        ).to(dtype=dtype, device=device)

        x.requires_grad = True
        y = sensenufft_ob.forward(x, ktraj)

        ((y ** 2) / 2).sum().backward()
        x_grad = x.grad.clone().detach()
GitHub: mmuckley/torchkbnufft — tests/test_adjoints.py (View on GitHub)
norm_tol = testing_tol

    im_size = params_3d['im_size']
    numpoints = params_3d['numpoints']

    x = params_3d['x']
    y = params_3d['y']
    ktraj = params_3d['ktraj']
    smap = params_3d['smap']

    for device in device_list:
        x = x.detach().to(dtype=dtype, device=device)
        y = y.detach().to(dtype=dtype, device=device)
        ktraj = ktraj.detach().to(dtype=dtype, device=device)

        sensenufft_ob = MriSenseNufft(
            smap=smap,
            im_size=im_size,
            numpoints=numpoints
        ).to(dtype=dtype, device=device)
        adjsensenufft_ob = AdjMriSenseNufft(
            smap=smap,
            im_size=im_size,
            numpoints=numpoints
        ).to(dtype=dtype, device=device)

        x_forw = sensenufft_ob(x, ktraj)
        y_back = adjsensenufft_ob(y, ktraj)

        inprod1 = inner_product(y, x_forw, dim=2)
        inprod2 = inner_product(y_back, x, dim=2)
GitHub: mmuckley/torchkbnufft — tests/test_pytorch_grad_adj_matching.py (View on GitHub)
norm_tol = testing_tol

    im_size = params_2d['im_size']
    numpoints = params_2d['numpoints']

    x = params_2d['x']
    y = params_2d['y']
    ktraj = params_2d['ktraj']
    smap = params_2d['smap']

    for device in device_list:
        x = x.detach().to(dtype=dtype, device=device)
        y = y.detach().to(dtype=dtype, device=device)
        ktraj = ktraj.detach().to(dtype=dtype, device=device)

        sensenufft_ob = MriSenseNufft(
            smap=smap,
            im_size=im_size,
            numpoints=numpoints
        ).to(dtype=dtype, device=device)
        adjsensenufft_ob = AdjMriSenseNufft(
            smap=smap,
            im_size=im_size,
            numpoints=numpoints
        ).to(dtype=dtype, device=device)

        y.requires_grad = True
        x = adjsensenufft_ob.forward(y, ktraj)

        ((x ** 2) / 2).sum().backward()
        y_grad = y.grad.clone().detach()
GitHub: mmuckley/torchkbnufft — tests/test_pytorch_grad_adj_matching.py (View on GitHub)
norm_tol = testing_tol

    im_size = params_3d['im_size']
    numpoints = params_3d['numpoints']

    x = params_3d['x']
    y = params_3d['y']
    ktraj = params_3d['ktraj']
    smap = params_3d['smap']

    for device in device_list:
        x = x.detach().to(dtype=dtype, device=device)
        y = y.detach().to(dtype=dtype, device=device)
        ktraj = ktraj.detach().to(dtype=dtype, device=device)

        sensenufft_ob = MriSenseNufft(
            smap=smap,
            im_size=im_size,
            numpoints=numpoints
        ).to(dtype=dtype, device=device)
        adjsensenufft_ob = AdjMriSenseNufft(
            smap=smap,
            im_size=im_size,
            numpoints=numpoints
        ).to(dtype=dtype, device=device)

        y.requires_grad = True
        x = adjsensenufft_ob.forward(y, ktraj)

        ((x ** 2) / 2).sum().backward()
        y_grad = y.grad.clone().detach()
GitHub: mmuckley/torchkbnufft — profile_torchkbnufft.py (View on GitHub)
num_nuffts = 20
        else:
            num_nuffts = 5
    else:
        dtype = torch.float
        if use_toep:
            num_nuffts = 50
        else:
            num_nuffts = 20
    cpudevice = torch.device('cpu')

    image = image.to(dtype=dtype)
    ktraj = ktraj.to(dtype=dtype)
    smap = smap.to(dtype=dtype)

    kbsense_ob = MriSenseNufft(smap=smap, im_size=im_size).to(
        dtype=dtype, device=device)
    adjkbsense_ob = AdjMriSenseNufft(
        smap=smap, im_size=im_size).to(dtype=dtype, device=device)

    adjkbnufft_ob = AdjKbNufft(im_size=im_size).to(dtype=dtype, device=device)

    # precompute toeplitz kernel if using toeplitz
    if use_toep:
        print('using toeplitz for forward/backward')
        kern = calc_toep_kernel(adjkbsense_ob, ktraj)
        toep_ob = ToepSenseNufft(smap=smap).to(dtype=dtype, device=device)

    # precompute the sparse interpolation matrices
    if sparse_mats_flag:
        print('using sparse interpolation matrices')
        real_mat, imag_mat = precomp_sparse_mats(ktraj, adjkbnufft_ob)