How to use the KbInterpBack class in torchkbnufft

To help you get started, we’ve selected a few torchkbnufft examples based on popular ways the library is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes—no build needed—and fix issues immediately.

GitHub: mmuckley/torchkbnufft — tests/test_kb_construction.py (view on GitHub)
# test 2d scalar inputs: build every interpolator/NUFFT class from one shared
# set of explicitly specified 2D parameters.
    # NOTE(review): no assertion is visible here, so this appears to be a
    # smoke test that passes as long as every constructor call succeeds —
    # confirm against the full test function.
    im_sz = (256, 256)
    # Random map of shape (1,) + im_sz, i.e. (1, 256, 256); the leading 1 is
    # presumably a single coil — confirm against the MriSenseNufft docs.
    smap = torch.randn(*((1,) + im_sz))
    grid_sz = (512, 512)  # twice im_sz along each dimension
    n_shift = (128, 128)  # half of im_sz along each dimension
    numpoints = 6
    table_oversamp = 2**10
    kbwidth = 2.34
    order = 0
    # Note: this is the string 'None', not the None object.
    # NOTE(review): presumably the library's spelling for "no normalization";
    # verify against the KbNufft documentation.
    norm = 'None'

    # `ob` is rebound by every construction below and never read; only the
    # constructor calls are exercised. The interpolators are not given `norm`.
    ob = KbInterpForw(
        im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, numpoints=numpoints,
        table_oversamp=table_oversamp, kbwidth=kbwidth, order=order)
    ob = KbInterpBack(
        im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, numpoints=numpoints,
        table_oversamp=table_oversamp, kbwidth=kbwidth, order=order)

    # KbNufft and AdjKbNufft additionally receive `norm`.
    ob = KbNufft(
        im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, numpoints=numpoints,
        table_oversamp=table_oversamp, kbwidth=kbwidth, order=order, norm=norm)
    ob = AdjKbNufft(
        im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, numpoints=numpoints,
        table_oversamp=table_oversamp, kbwidth=kbwidth, order=order, norm=norm)

    # The MriSenseNufft variants additionally receive the map `smap`.
    ob = MriSenseNufft(
        smap=smap, im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, numpoints=numpoints,
        table_oversamp=table_oversamp, kbwidth=kbwidth, order=order, norm=norm)
    ob = AdjMriSenseNufft(
        smap=smap, im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, numpoints=numpoints,
        table_oversamp=table_oversamp, kbwidth=kbwidth, order=order, norm=norm)
GitHub: mmuckley/torchkbnufft — tests/test_kb_construction.py (view on GitHub)
# test 3d scalar inputs: build every interpolator/NUFFT class from one shared
# set of explicitly specified 3D parameters.
    # NOTE(review): no assertion is visible here, so this appears to be a
    # smoke test that passes as long as every constructor call succeeds —
    # confirm against the full test function.
    im_sz = (10, 256, 256)
    # Random map of shape (1,) + im_sz, i.e. (1, 10, 256, 256); the leading 1
    # is presumably a single coil — confirm against the MriSenseNufft docs.
    smap = torch.randn(*((1,) + im_sz))
    # First dimension kept at im_sz's size; the last two are doubled.
    grid_sz = (10, 512, 512)
    n_shift = (5, 128, 128)  # half of im_sz along each dimension
    numpoints = 6
    table_oversamp = 2**10
    kbwidth = 2.34
    order = 0
    # Note: this is the string 'None', not the None object.
    # NOTE(review): presumably the library's spelling for "no normalization";
    # verify against the KbNufft documentation.
    norm = 'None'

    # `ob` is rebound by every construction below and never read; only the
    # constructor calls are exercised. The interpolators are not given `norm`.
    ob = KbInterpForw(
        im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, numpoints=numpoints,
        table_oversamp=table_oversamp, kbwidth=kbwidth, order=order)
    ob = KbInterpBack(
        im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, numpoints=numpoints,
        table_oversamp=table_oversamp, kbwidth=kbwidth, order=order)

    # KbNufft and AdjKbNufft additionally receive `norm`.
    ob = KbNufft(
        im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, numpoints=numpoints,
        table_oversamp=table_oversamp, kbwidth=kbwidth, order=order, norm=norm)
    ob = AdjKbNufft(
        im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, numpoints=numpoints,
        table_oversamp=table_oversamp, kbwidth=kbwidth, order=order, norm=norm)

    # The MriSenseNufft variants additionally receive the map `smap`.
    ob = MriSenseNufft(
        smap=smap, im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, numpoints=numpoints,
        table_oversamp=table_oversamp, kbwidth=kbwidth, order=order, norm=norm)
    ob = AdjMriSenseNufft(
        smap=smap, im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, numpoints=numpoints,
        table_oversamp=table_oversamp, kbwidth=kbwidth, order=order, norm=norm)
GitHub: mmuckley/torchkbnufft — tests/test_pytorch_grad_adj_matching.py (view on GitHub)
1j*np.random.normal(size=(batch_size, 1) + grid_size)
    x = torch.tensor(np.stack((np.real(x), np.imag(x)), axis=2))
    y = params_2d['y']
    ktraj = params_2d['ktraj']

    for device in device_list:
        x = x.detach().to(dtype=dtype, device=device)
        y = y.detach().to(dtype=dtype, device=device)
        ktraj = ktraj.detach().to(dtype=dtype, device=device)

        kbinterp_ob = KbInterpForw(
            im_size=im_size,
            grid_size=grid_size,
            numpoints=numpoints
        ).to(dtype=dtype, device=device)
        adjkbinterp_ob = KbInterpBack(
            im_size=im_size,
            grid_size=grid_size,
            numpoints=numpoints
        ).to(dtype=dtype, device=device)

        x.requires_grad = True
        y = kbinterp_ob.forward(x, ktraj)

        ((y ** 2) / 2).sum().backward()
        x_grad = x.grad.clone().detach()

        x_hat = adjkbinterp_ob.forward(y.clone().detach(), ktraj)

        assert torch.norm(x_grad-x_hat) < norm_tol
GitHub: mmuckley/torchkbnufft — tests/test_sparse_adjoints.py (view on GitHub)
1j*np.random.normal(size=(batch_size, 1) + grid_size)
    x = torch.tensor(np.stack((np.real(x), np.imag(x)), axis=2))
    y = params_2d['y']
    ktraj = params_2d['ktraj']

    for device in device_list:
        x = x.detach().to(dtype=dtype, device=device)
        y = y.detach().to(dtype=dtype, device=device)
        ktraj = ktraj.detach().to(dtype=dtype, device=device)

        kbinterp_ob = KbInterpForw(
            im_size=im_size,
            grid_size=grid_size,
            numpoints=numpoints
        ).to(dtype=dtype, device=device)
        adjkbinterp_ob = KbInterpBack(
            im_size=im_size,
            grid_size=grid_size,
            numpoints=numpoints
        ).to(dtype=dtype, device=device)

        real_mat, imag_mat = precomp_sparse_mats(ktraj, kbinterp_ob)
        interp_mats = {
            'real_interp_mats': real_mat,
            'imag_interp_mats': imag_mat
        }

        x_forw = kbinterp_ob(x, ktraj, interp_mats)
        y_back = adjkbinterp_ob(y, ktraj, interp_mats)

        inprod1 = inner_product(y, x_forw, dim=2)
        inprod2 = inner_product(y_back, x, dim=2)
GitHub: mmuckley/torchkbnufft — tests/test_pytorch_grad_adj_matching.py (view on GitHub)
1j*np.random.normal(size=(batch_size, 1) + grid_size)
    x = torch.tensor(np.stack((np.real(x), np.imag(x)), axis=2))
    y = params_3d['y']
    ktraj = params_3d['ktraj']

    for device in device_list:
        x = x.detach().to(dtype=dtype, device=device)
        y = y.detach().to(dtype=dtype, device=device)
        ktraj = ktraj.detach().to(dtype=dtype, device=device)

        kbinterp_ob = KbInterpForw(
            im_size=im_size,
            grid_size=grid_size,
            numpoints=numpoints
        ).to(dtype=dtype, device=device)
        adjkbinterp_ob = KbInterpBack(
            im_size=im_size,
            grid_size=grid_size,
            numpoints=numpoints
        ).to(dtype=dtype, device=device)

        y.requires_grad = True
        x = adjkbinterp_ob.forward(y, ktraj)

        ((x ** 2) / 2).sum().backward()
        y_grad = y.grad.clone().detach()

        y_hat = kbinterp_ob.forward(x.clone().detach(), ktraj)

        assert torch.norm(y_grad-y_hat) < norm_tol
GitHub: mmuckley/torchkbnufft — tests/test_pytorch_sparse_grad_adj_matching.py (view on GitHub)
1j*np.random.normal(size=(batch_size, 1) + grid_size)
    x = torch.tensor(np.stack((np.real(x), np.imag(x)), axis=2))
    y = params_3d['y']
    ktraj = params_3d['ktraj']

    for device in device_list:
        x = x.detach().to(dtype=dtype, device=device)
        y = y.detach().to(dtype=dtype, device=device)
        ktraj = ktraj.detach().to(dtype=dtype, device=device)

        kbinterp_ob = KbInterpForw(
            im_size=im_size,
            grid_size=grid_size,
            numpoints=numpoints
        ).to(dtype=dtype, device=device)
        adjkbinterp_ob = KbInterpBack(
            im_size=im_size,
            grid_size=grid_size,
            numpoints=numpoints
        ).to(dtype=dtype, device=device)

        real_mat, imag_mat = precomp_sparse_mats(ktraj, kbinterp_ob)
        interp_mats = {
            'real_interp_mats': real_mat,
            'imag_interp_mats': imag_mat
        }

        x.requires_grad = True
        y = kbinterp_ob.forward(x, ktraj, interp_mats)

        ((y ** 2) / 2).sum().backward()
        x_grad = x.grad.clone().detach()
GitHub: mmuckley/torchkbnufft — tests/test_kb_construction.py (view on GitHub)
kbwidths = [2.34, 5]
    orders = [0, 2]

    for kbwidth in kbwidths:
        for order in orders:
            for im_sz in im_szs:
                smap = torch.randn(*((1,) + im_sz))

                base_table = AdjKbNufft(
                    im_sz, order=order, kbwidth=kbwidth).table

                cur_table = KbNufft(im_sz, order=order, kbwidth=kbwidth).table
                check_tables(base_table, cur_table)

                cur_table = KbInterpBack(
                    im_sz, order=order, kbwidth=kbwidth).table
                check_tables(base_table, cur_table)

                cur_table = KbInterpForw(
                    im_sz, order=order, kbwidth=kbwidth).table
                check_tables(base_table, cur_table)

                cur_table = MriSenseNufft(
                    smap, im_sz, order=order, kbwidth=kbwidth).table
                check_tables(base_table, cur_table)

                cur_table = AdjMriSenseNufft(
                    smap, im_sz, order=order, kbwidth=kbwidth).table
                check_tables(base_table, cur_table)
GitHub: mmuckley/torchkbnufft — tests/test_pytorch_sparse_grad_adj_matching.py (view on GitHub)
1j*np.random.normal(size=(batch_size, 1) + grid_size)
    x = torch.tensor(np.stack((np.real(x), np.imag(x)), axis=2))
    y = params_2d['y']
    ktraj = params_2d['ktraj']

    for device in device_list:
        x = x.detach().to(dtype=dtype, device=device)
        y = y.detach().to(dtype=dtype, device=device)
        ktraj = ktraj.detach().to(dtype=dtype, device=device)

        kbinterp_ob = KbInterpForw(
            im_size=im_size,
            grid_size=grid_size,
            numpoints=numpoints
        ).to(dtype=dtype, device=device)
        adjkbinterp_ob = KbInterpBack(
            im_size=im_size,
            grid_size=grid_size,
            numpoints=numpoints
        ).to(dtype=dtype, device=device)

        real_mat, imag_mat = precomp_sparse_mats(ktraj, kbinterp_ob)
        interp_mats = {
            'real_interp_mats': real_mat,
            'imag_interp_mats': imag_mat
        }

        x.requires_grad = True
        y = kbinterp_ob.forward(x, ktraj, interp_mats)

        ((y ** 2) / 2).sum().backward()
        x_grad = x.grad.clone().detach()
GitHub: mmuckley/torchkbnufft — tests/test_pytorch_sparse_grad_adj_matching.py (view on GitHub)
1j*np.random.normal(size=(batch_size, 1) + grid_size)
    x = torch.tensor(np.stack((np.real(x), np.imag(x)), axis=2))
    y = params_3d['y']
    ktraj = params_3d['ktraj']

    for device in device_list:
        x = x.detach().to(dtype=dtype, device=device)
        y = y.detach().to(dtype=dtype, device=device)
        ktraj = ktraj.detach().to(dtype=dtype, device=device)

        kbinterp_ob = KbInterpForw(
            im_size=im_size,
            grid_size=grid_size,
            numpoints=numpoints
        ).to(dtype=dtype, device=device)
        adjkbinterp_ob = KbInterpBack(
            im_size=im_size,
            grid_size=grid_size,
            numpoints=numpoints
        ).to(dtype=dtype, device=device)

        real_mat, imag_mat = precomp_sparse_mats(ktraj, kbinterp_ob)
        interp_mats = {
            'real_interp_mats': real_mat,
            'imag_interp_mats': imag_mat
        }

        y.requires_grad = True
        x = adjkbinterp_ob.forward(y, ktraj, interp_mats)

        ((x ** 2) / 2).sum().backward()
        y_grad = y.grad.clone().detach()
GitHub: mmuckley/torchkbnufft — tests/test_sparse_adjoints.py (view on GitHub)
1j*np.random.normal(size=(batch_size, 1) + grid_size)
    x = torch.tensor(np.stack((np.real(x), np.imag(x)), axis=2))
    y = params_3d['y']
    ktraj = params_3d['ktraj']

    for device in device_list:
        x = x.detach().to(dtype=dtype, device=device)
        y = y.detach().to(dtype=dtype, device=device)
        ktraj = ktraj.detach().to(dtype=dtype, device=device)

        kbinterp_ob = KbInterpForw(
            im_size=im_size,
            grid_size=grid_size,
            numpoints=numpoints
        ).to(dtype=dtype, device=device)
        adjkbinterp_ob = KbInterpBack(
            im_size=im_size,
            grid_size=grid_size,
            numpoints=numpoints
        ).to(dtype=dtype, device=device)

        real_mat, imag_mat = precomp_sparse_mats(ktraj, kbinterp_ob)
        interp_mats = {
            'real_interp_mats': real_mat,
            'imag_interp_mats': imag_mat
        }

        x_forw = kbinterp_ob(x, ktraj, interp_mats)
        y_back = adjkbinterp_ob(y, ktraj, interp_mats)

        inprod1 = inner_product(y, x_forw, dim=2)
        inprod2 = inner_product(y_back, x, dim=2)