How to use the sigpy.backend.get_array_module function in sigpy

To help you get started, we’ve selected a few sigpy examples based on popular ways this function is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

GitHub: mikgroup/sigpy — sigpy/fourier.py (View on GitHub)
def _apodize(input, ndim, oversamp, width, beta):
    """Scale ``input`` in place by a separable apodization window.

    For each of the last ``ndim`` axes (length ``n``), the window value at
    index ``k`` is ``r / sinh(r)``, where
    ``r = sqrt(beta**2 - (pi * width * (k - n // 2) / ceil(oversamp * n))**2)``.
    The window is computed in the array's own dtype. It is broadcast along
    that axis only and multiplied into the array.

    NOTE(review): for a real dtype, a negative radicand produces NaN. A
    complex dtype yields an imaginary ``r`` instead. Presumably ``beta`` is
    chosen so that neither case matters; confirm against callers.

    Args:
        input (array): Array to apodize. It is modified in place.
        ndim (int): Number of trailing axes to process.
        oversamp (float): Oversampling factor; scales the window argument.
        width (float): Width parameter appearing in the window argument.
        beta (float): Window parameter (presumably the Kaiser-Bessel beta;
            confirm).

    Returns:
        array: ``input`` itself, after in-place scaling.
    """
    xp = backend.get_array_module(input)
    # Alias, not a copy: callers may rely on the in-place update.
    result = input
    for axis in range(-ndim, 0):
        n = result.shape[axis]
        n_os = ceil(oversamp * n)
        # Sample positions share the array's dtype so the window does too.
        positions = xp.arange(n, dtype=result.dtype)
        arg = np.pi * width * (positions - n // 2) / n_os
        window = (beta**2 - arg**2)**0.5
        window /= xp.sinh(window)
        # Shape (n, 1, ..., 1) so the window broadcasts along `axis` only.
        broadcast_shape = [n] + [1] * (-axis - 1)
        result *= window.reshape(broadcast_shape)
    return result
GitHub: mikgroup/sigpy — sigpy/fourier.py (View on GitHub)
def _fftc(input, oshape=None, axes=None, norm='ortho'):
    """Centered FFT: ``fftshift(fftn(ifftshift(resize(input, oshape))))``.

    Args:
        input (array): Input array.
        oshape (None or tuple of ints): Shape handed to ``util.resize``
            before transforming. ``None`` keeps ``input.shape``.
        axes (None or tuple of ints): Axes to transform, normalized via
            ``util._normalize_axes`` against ``input.ndim``.
        norm (str): Normalization mode forwarded to ``fftn``.

    Returns:
        array: Spectrum with the centering shifts applied over ``axes``.
    """
    axes = util._normalize_axes(axes, input.ndim)
    xp = backend.get_array_module(input)
    target_shape = input.shape if oshape is None else oshape

    # Move the center to index 0, transform, then move it back.
    shifted = xp.fft.ifftshift(util.resize(input, target_shape), axes=axes)
    spectrum = xp.fft.fftn(shifted, axes=axes, norm=norm)
    return xp.fft.fftshift(spectrum, axes=axes)
GitHub: mikgroup/sigpy — sigpy/util.py (View on GitHub)
def rss(input, axes=(0, )):
    """Compute the root sum of squares of ``input`` over ``axes``.

    Args:
        input (array): Input array (real or complex; magnitudes are used).
        axes (None or tuple of ints): Axes to reduce. ``None`` reduces over
            all axes.

    Returns:
        array: ``sqrt(sum(|input|**2))`` taken over ``axes``.
    """
    xp = backend.get_array_module(input)
    power = xp.abs(input)**2
    return xp.sum(power, axis=axes)**0.5
GitHub: mikgroup/sigpy — sigpy/conv.py (View on GitHub)
def _convolve_cuda(data, filt,
                       mode='full', strides=None,
                       multi_channel=False):
        """Set up a convolution of ``data`` with ``filt``.

        Only the setup portion is visible in this excerpt. It looks up the
        array module and obtains shape parameters from
        ``_get_convolve_params``. It then reshapes ``data`` and ``filt`` and
        allocates an uninitialized ``output`` buffer of dtype ``data.dtype``.

        NOTE(review): the excerpt ends without a return or any call that
        fills ``output``; the rest of the function is not shown. The locals
        ``dilations``, ``groups``, ``auto_tune``, ``tensor_core`` and
        ``pads`` are unused in the visible lines. Presumably a later
        (cuDNN?) call consumes them; confirm against the full source.

        Args:
            data (array): Input array.
            filt (array): Filter array.
            mode (str): 'full' or 'valid'; selects the padding below.
            strides (None or tuple of ints): Forwarded to
                ``_get_convolve_params``.
            multi_channel (bool): Forwarded to ``_get_convolve_params``.
        """
        xp = backend.get_array_module(data)

        # Shape bookkeeping. From the reshapes below, m is the trailing shape
        # of data, n of filt, and p of output. B, c_i and c_o are leading
        # dims (presumably batch and input/output channels; confirm).
        D, b, B, m, n, s, c_i, c_o, p = _get_convolve_params(
            data.shape, filt.shape,
            mode, strides, multi_channel)
        dilations = (1, ) * D
        groups = 1
        auto_tune = True
        tensor_core = 'auto'
        # 'full' pads each filter dim by n_d - 1; 'valid' uses zero padding.
        # NOTE(review): any other mode leaves `pads` unbound, so a later use
        # would raise UnboundLocalError. Confirm mode is validated upstream.
        if mode == 'full':
            pads = tuple(n_d - 1 for n_d in n)
        elif mode == 'valid':
            pads = (0, ) * D

        data = data.reshape((B, c_i) + m)
        filt = filt.reshape((c_o, c_i) + n)
        # Uninitialized buffer; presumably filled by code past this excerpt.
        output = xp.empty((B, c_o) + p, dtype=data.dtype)
GitHub: mikgroup/sigpy — sigpy/util.py (View on GitHub)
Args:
        input (array): Input array.
        oshape (tuple of ints): Shape of the zero-filled output array.
        factors (tuple of ints): Upsampling factors (per-axis strides).
        shift (None or tuple of ints): Per-axis start offsets. None means
            zero offset on every axis.

    Returns:
        array: Array of shape ``oshape`` and dtype ``input.dtype``. It is
        zero except at the strided positions, which hold ``input``.
    """

    # Default to no offset on every axis.
    if shift is None:
        shift = [0] * len(factors)

    # Axis k selects every factors[k]-th element starting at shift[k].
    # zip() stops at the shorter of shift/factors, so only that many
    # leading axes are strided; any remaining axes are taken whole.
    slc = tuple(slice(s, None, f) for s, f in zip(shift, factors))

    xp = backend.get_array_module(input)
    output = xp.zeros(oshape, dtype=input.dtype)
    # input must match (or broadcast to) the shape of output[slc].
    output[slc] = input

    return output
GitHub: mikgroup/sigpy — sigpy/conv.py (View on GitHub)
def _convolve_filter_adjoint_cuda(output, data, filt_shape,
                                      mode='full', strides=None,
                                      multi_channel=False):
        """Set up the filter-adjoint convolution from ``output`` and ``data``.

        Only the setup portion is visible in this excerpt. It looks up the
        array module and obtains shape parameters from
        ``_get_convolve_params`` (using ``filt_shape`` rather than a filter
        array). It then reshapes ``data`` and ``output``.

        NOTE(review): the excerpt ends without a return. ``xp``,
        ``dilations``, ``groups``, ``auto_tune``, ``tensor_core``,
        ``deterministic`` and ``pads`` are unused in the visible lines.
        Presumably a later (cuDNN?) call that computes the filter gradient
        consumes them; confirm against the full source.

        Args:
            output (array): Array reshaped to ``(B, c_o) + p`` below.
            data (array): Array reshaped to ``(B, c_i) + m`` below.
            filt_shape (tuple of ints): Filter shape passed to
                ``_get_convolve_params``.
            mode (str): 'full' or 'valid'; selects the padding below.
            strides (None or tuple of ints): Forwarded to
                ``_get_convolve_params``.
            multi_channel (bool): Forwarded to ``_get_convolve_params``.
        """
        xp = backend.get_array_module(data)

        # Shape bookkeeping. From the reshapes below, m is the trailing shape
        # of data and p of output; n is used for padding. B, c_i and c_o are
        # leading dims (presumably batch and input/output channels; confirm).
        D, b, B, m, n, s, c_i, c_o, p = _get_convolve_params(
            data.shape, filt_shape,
            mode, strides, multi_channel)
        dilations = (1, ) * D
        groups = 1
        auto_tune = True
        tensor_core = 'auto'
        deterministic = False
        # 'full' pads each filter dim by n_d - 1; 'valid' uses zero padding.
        # NOTE(review): any other mode leaves `pads` unbound, so a later use
        # would raise UnboundLocalError. Confirm mode is validated upstream.
        if mode == 'full':
            pads = tuple(n_d - 1 for n_d in n)
        elif mode == 'valid':
            pads = (0, ) * D

        data = data.reshape((B, c_i) + m)
        output = output.reshape((B, c_o) + p)
GitHub: mikgroup/sigpy — sigpy/thresh.py (View on GitHub)
r"""Soft threshold.

    Performs:

    .. math::
        (| x | - \lambda)_+  \text{sgn}(x)

    where :math:`x` is ``input`` and :math:`\lambda` is ``lamda``.

    Args:
        lamda (float or array): Threshold parameter. It is spelled
            ``lamda`` because ``lambda`` is a Python keyword.
        input (array): Input array to threshold.

    Returns:
        array: soft-thresholded result.

    """
    xp = backend.get_array_module(input)
    # Dispatch on the array module. NumPy arrays go to `_soft_thresh`; any
    # other module (presumably CuPy) goes to `_soft_thresh_cuda`.
    if xp == np:
        return _soft_thresh(lamda, input)
    else:  # pragma: no cover
        return _soft_thresh_cuda(lamda, input)