How to use the sigpy.linop.Linop function in sigpy

To help you get started, we’ve selected a few sigpy examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github mikgroup / sigpy / sigpy / linop.py View on Github external
return Add([self, input])
        else:
            raise NotImplementedError

    def __neg__(self):
        """Return the negation of this Linop, i.e. ``(-1) * self``."""
        negated = -1 * self
        return negated

    def __sub__(self, input):
        """Return ``self - input``, implemented as ``self + (-input)``."""
        negated_input = -input
        return self.__add__(negated_input)

    def __repr__(self):
        """Return a short description with output/input shapes and repr string."""
        return '<{}x{}> {} Linop>'.format(
            self.oshape, self.ishape, self.repr_str)


class Identity(Linop):
    """Identity linear operator.

    Returns input directly.

    Args:
        shape (tuple of ints): Input shape

    """

    def __init__(self, shape):
        # The identity operator maps a shape onto itself, so the input
        # shape doubles as the output shape.
        super().__init__(shape, shape)

    def _apply(self, input):
        # Identity: the input is returned unchanged (no copy is made).
        return input

    def _adjoint_linop(self):
github mikgroup / sigpy / sigpy / linop.py View on Github external
def _apply(self, input):
        """Convolve the stored data array with the input filter.

        The data is moved to the input's device before convolving.
        """
        dev = backend.get_device(input)
        data_on_dev = backend.to_device(self.data, dev)
        with dev:
            return conv.convolve(
                data_on_dev, input,
                mode=self.mode,
                strides=self.strides,
                multi_channel=self.multi_channel)

    def _adjoint_linop(self):
        """Return the adjoint operator, which convolves against the same data."""
        conv_opts = dict(mode=self.mode, strides=self.strides,
                         multi_channel=self.multi_channel)
        return ConvolveFilterAdjoint(self.ishape, self.data, **conv_opts)


class ConvolveFilterAdjoint(Linop):
    r"""Adjoint convolution operator for filter arrays.

    Args:
        filt_shape (tuple of ints): filter array shape:
            :math:`[n_1, \ldots, n_D]` if multi_channel is False
            :math:`[c_o, c_i, n_1, \ldots, n_D]` otherwise.
        data (array): data array of shape:
            :math:`[\ldots, m_1, \ldots, m_D]` if multi_channel is False,
            :math:`[\ldots, c_i, m_1, \ldots, m_D]` otherwise.
        mode (str): {'full', 'valid'}.
        strides (None or tuple of ints): convolution strides of length D.
        multi_channel (bool): specify if input/output has multiple channels.

    """
    def __init__(self, filt_shape, data,
                 mode='full', strides=None,
github mikgroup / sigpy / sigpy / linop.py View on Github external
raise Exception(
                'Shapes must have the same lengths to concatenate.')

        for i in range(ndim):
            if i == axis:
                ishape[i] += shape[i]
                indices.append(idx)
                idx += shape[i]
            elif shape[i] != ishape[i]:
                raise RuntimeError(
                    'Shapes not along axis must be the same to concatenate.')

    return ishape, indices


class Hstack(Linop):
    """Horizontally stack linear operators.

    Creates a Linop that splits the input, applies Linops independently,
    and sums outputs.
    In matrix form, given matrices {A1, ..., An}, this is equivalent to
    the block matrix [A1, ..., An].

    Input and output devices must be the same.

    Args:
        linops (list of Linops): list of linops with the same output shape.
        axis (int or None): If None, inputs are vectorized and concatenated.
            Otherwise, inputs are stacked along axis.

    """
github mikgroup / sigpy / sigpy / linop.py View on Github external
super().__init__(oshape, ishape)

    def _apply(self, input):
        """Apply the adjoint NUFFT to ``input`` on the input's device."""
        with backend.get_device(input):
            return fourier.nufft_adjoint(
                input,
                self.coord,
                self.oshape,
                oversamp=self.oversamp,
                width=self.width)

    def _adjoint_linop(self):
        """The adjoint of the adjoint NUFFT is the forward NUFFT."""
        forward = NUFFT(self.oshape, self.coord,
                        oversamp=self.oversamp, width=self.width)
        return forward


class ConvolveData(Linop):
    r"""Convolution operator for data arrays.

    Args:
        data_shape (tuple of ints): data array shape:
            :math:`[\ldots, m_1, \ldots, m_D]` if multi_channel is False,
            :math:`[\ldots, c_i, m_1, \ldots, m_D]` otherwise.
        filt (array): filter array of shape:
            :math:`[n_1, \ldots, n_D]` if multi_channel is False
            :math:`[c_o, c_i, n_1, \ldots, n_D]` otherwise.
        mode (str): {'full', 'valid'}.
        strides (None or tuple of ints): convolution strides of length D.
        multi_channel (bool): specify if input/output has multiple channels.

    """
    def __init__(self, data_shape, filt, mode='full', strides=None,
                 multi_channel=False):
github mikgroup / sigpy / sigpy / linop.py View on Github external
def _apply(self, input):
        """Interpolate ``input`` at ``self.coord`` on the input's device."""
        dev = backend.get_device(input)
        with dev:
            # Coordinates are moved to the input's device before use.
            pts = backend.to_device(self.coord, dev)
            return interp.interpolate(
                input, pts,
                kernel=self.kernel,
                width=self.width,
                param=self.param)

    def _adjoint_linop(self):
        """Gridding with the same coordinates is the adjoint of interpolation."""
        return Gridding(self.ishape, self.coord, kernel=self.kernel,
                        width=self.width, param=self.param)


class Gridding(Linop):
    """Gridding linear operator.

    Args:
        oshape (tuple of ints): Output shape = batch_shape + pts_shape
        ishape (tuple of ints): Input shape = batch_shape + grd_shape
        coord (array): Coordinates of shape pts_shape + [ndim], with values
                from -nx / 2 to nx / 2 - 1. ndim can only be 1, 2, or 3.
        width (float): Width of interp. kernel in grid size.
        kernel (str): Interpolation kernel, {'spline', 'kaiser_bessel'}.
        param (float): Kernel parameter.

    See Also:
        :func:`sigpy.gridding`

    """
github mikgroup / sigpy / sigpy / linop.py View on Github external
"""
    I = Identity(ishape)
    ndim = len(ishape)
    axes = util._normalize_axes(axes, ndim)
    linops = []
    for i in axes:
        D = I - Circshift(ishape, [1], axes=[i])
        R = Reshape([1] + list(ishape), ishape)
        linops.append(R * D)

    G = Vstack(linops, axis=0)

    return G


class NUFFT(Linop):
    """NUFFT linear operator.

    Args:
        ishape (tuple of int): Input shape.
        coord (array): Coordinates, with values [-ishape / 2, ishape / 2]
        oversamp (float): Oversampling factor.
        width (float): Kernel width.
        n (int): Kernel sampling number.

    """
    def __init__(self, ishape, coord, oversamp=1.25, width=4):
        self.coord = coord
        self.oversamp = oversamp
        self.width = width

        ndim = coord.shape[-1]
github mikgroup / sigpy / sigpy / linop.py View on Github external
if not (i == m or i == 1 or m == 1):
            raise ValueError('Invalid shapes: {ishape}, {mshape}.'.format(
                ishape=ishape, mshape=mshape))

        oshape.append(max(i, m))

    if ishape_exp[-1] != mshape_exp[-2]:
        raise ValueError('Invalid shapes: {ishape}, {mshape}.'.format(
            ishape=ishape, mshape=mshape))

    oshape += [ishape_exp[-2], mshape_exp[-1]]

    return oshape


class RightMatMul(Linop):
    """Matrix multiplication on the right.

    Args:
        ishape (tuple of ints): Input shape.
            It must be able to broadcast with mat.shape.
        mat (array): Matrix of shape [..., m, n]
        adjoint (bool): Toggle adjoint.
            If True, performs conj(mat).swapaxes(-1, -2)
            before performing matrix multiplication.

    """

    def __init__(self, ishape, mat, adjoint=False):
        self.mat = mat
        self.adjoint = adjoint
github mikgroup / sigpy / sigpy / linop.py View on Github external
self.expanded_ishape.append(oshape[d])
                self.reps.append(1)

        super().__init__(oshape, ishape)

    def _apply(self, input):
        """Reshape the input to the expanded shape, then tile along each axis."""
        dev = backend.get_device(input)
        with dev:
            expanded = input.reshape(self.expanded_ishape)
            return dev.xp.tile(expanded, self.reps)

    def _adjoint_linop(self):
        # Summing over the tiled axes undoes the replication performed
        # by _apply, which is exactly the adjoint of tiling.
        return Sum(self.oshape, self.axes)


class ArrayToBlocks(Linop):
    """Extract blocks from an array in a sliding window manner.

    Args:
        ishape (array): input array of shape [..., N_1, ..., N_D]
        blk_shape (tuple): block shape of length D, with D <= 4.
        blk_strides (tuple): block strides of length D.

    See Also:
        :func:`sigpy.block.array_to_blocks`

    """

    def __init__(self, ishape, blk_shape, blk_strides):
        self.blk_shape = blk_shape
        self.blk_strides = blk_strides
        D = len(blk_shape)
github mikgroup / sigpy / sigpy / linop.py View on Github external
self.ishift = ishift
        self.oshift = oshift

        super().__init__(oshape, ishape)

    def _apply(self, input):
        """Resize ``input`` to ``self.oshape`` using the configured shifts."""
        dev = backend.get_device(input)
        with dev:
            return util.resize(input, self.oshape,
                               ishift=self.ishift, oshift=self.oshift)

    def _adjoint_linop(self):
        """Return the adjoint resize.

        The input/output shapes are exchanged, and the input/output
        shifts are swapped along with them.
        """
        swapped_shifts = {'ishift': self.oshift, 'oshift': self.ishift}
        return Resize(self.ishape, self.oshape, **swapped_shifts)


class Flip(Linop):
    """Flip linear operator.

    Args:
        shape (tuple of int): Input shape
    """

    def __init__(self, shape, axes=None):
        # Axes to flip; None is forwarded to util.flip as the default
        # (presumably flipping all axes — confirm against util.flip).
        self.axes = axes

        super().__init__(shape, shape)

    def _apply(self, input):
        """Flip ``input`` along the configured axes on the input's device."""
        with backend.get_device(input):
            return util.flip(input, self.axes)
github mikgroup / sigpy / sigpy / linop.py View on Github external
def _apply(self, input):
        # Permute the array axes; axes=None gives the array's default
        # transpose (full reversal of dimensions).
        return input.transpose(self.axes)

    def _adjoint_linop(self):
        """Return the transpose with the inverse axis permutation."""
        if self.axes is None:
            # A full axis reversal is its own inverse; the output shape
            # of the adjoint is the reversed input shape.
            return Transpose(self.ishape[::-1], axes=None)

        inverse_axes = np.argsort(self.axes)
        permuted_shape = [self.ishape[axis] for axis in self.axes]
        return Transpose(permuted_shape, axes=inverse_axes)


class FFT(Linop):
    """FFT linear operator.

    Args:
        ishape (tuple of int): Input shape
        axes (None or tuple of int): Axes to perform FFT.
            If None, applies on all axes.
        center (bool): Toggle center FFT.

    """

    def __init__(self, shape, axes=None, center=True):
        # Axes over which to apply the FFT; None means all axes
        # (per the class docstring).
        self.axes = axes
        # Toggle centered FFT (per the class docstring).
        self.center = center

        super().__init__(shape, shape)