Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def _scale_coord(coord, shape, oversamp):
    """Scale and shift coordinates onto the oversampled grid.

    For each of the trailing ``ndim = coord.shape[-1]`` entries ``s`` of
    ``shape``, the matching component along the last axis of ``coord`` is
    mapped to ``coord * (g / s) + g // 2`` with
    ``g = _get_ugly_number(oversamp * s)`` (helper defined elsewhere).

    NOTE(review): ``_scale_coord`` is defined twice in this file; the later
    definition is the one that takes effect.

    Args:
        coord (array): Coordinates whose last axis holds one component per
            transformed dimension.
        shape (tuple of int): Shape whose trailing ``ndim`` entries are the
            grid sizes along the transformed dimensions.
        oversamp (float): Oversampling factor.

    Returns:
        array: Scaled and shifted coordinates on the same device as
        ``coord``. Floating-point coordinates keep their dtype; any other
        dtype is promoted to float64 (as before).
    """
    ndim = coord.shape[-1]
    device = backend.get_device(coord)
    xp = device.xp
    dims = shape[-ndim:]
    sizes = [_get_ugly_number(oversamp * i) for i in dims]

    # Arrays built from plain Python lists default to float64/int64 and
    # would silently upcast e.g. float32 coordinates (and anything that
    # later keys on coord.dtype). Build scale/shift in the coord's own
    # floating dtype instead.
    dtype = coord.dtype if coord.dtype.kind == 'f' else xp.float64
    with device:
        scale = xp.asarray([g / i for g, i in zip(sizes, dims)], dtype=dtype)
        shift = xp.asarray([g // 2 for g in sizes], dtype=dtype)
        coord = scale * coord + shift

    return coord
def _scale_coord(coord, shape, oversamp):
    """Scale and shift coordinates onto the oversampled grid.

    For each of the trailing ``ndim = coord.shape[-1]`` entries ``s`` of
    ``shape``, the matching component along the last axis of ``coord`` is
    mapped to ``coord * (g / s) + g // 2`` with
    ``g = _get_ugly_number(oversamp * s)`` (helper defined elsewhere).

    Args:
        coord (array): Coordinates whose last axis holds one component per
            transformed dimension.
        shape (tuple of int): Shape whose trailing ``ndim`` entries are the
            grid sizes along the transformed dimensions.
        oversamp (float): Oversampling factor.

    Returns:
        array: Scaled and shifted coordinates on the same device as
        ``coord``. Floating-point coordinates keep their dtype; any other
        dtype is promoted to float64 (as before).
    """
    ndim = coord.shape[-1]
    device = backend.get_device(coord)
    xp = device.xp
    dims = shape[-ndim:]
    sizes = [_get_ugly_number(oversamp * i) for i in dims]

    # Arrays built from plain Python lists default to float64/int64 and
    # would silently upcast e.g. float32 coordinates (and anything that
    # later keys on coord.dtype). Build scale/shift in the coord's own
    # floating dtype instead.
    dtype = coord.dtype if coord.dtype.kind == 'f' else xp.float64
    with device:
        scale = xp.asarray([g / i for g, i in zip(sizes, dims)], dtype=dtype)
        shift = xp.asarray([g // 2 for g in sizes], dtype=dtype)
        coord = scale * coord + shift

    return coord
def fwt(input, wave_name='db4', axes=None, level=None):
    """Forward wavelet transform.

    The input is copied to the CPU and resized with ``util.resize`` so that
    every dimension (not only those in ``axes``) has even length. It is then
    decomposed with ``pywt.wavedecn`` using ``mode='zero'``. The coefficients
    are packed into a single array with ``pywt.coeffs_to_array``, and that
    array is moved back to the device of ``input``.

    Args:
        input (array): Input array.
        wave_name (str): Wavelet name.
        axes (None or tuple of int): Axes to perform wavelet transform.
        level (None or int): Number of wavelet levels.

    Returns:
        array: Packed wavelet coefficients, on the same device as ``input``.
    """
    source_device = backend.get_device(input)
    host_input = backend.to_device(input, backend.cpu_device)

    # Round each dimension up to the next even length (odd n -> n + 1).
    even_shape = [n + n % 2 for n in host_input.shape]
    padded = util.resize(host_input, even_shape)

    coeffs = pywt.wavedecn(padded, wave_name, mode='zero',
                           axes=axes, level=level)
    packed = pywt.coeffs_to_array(coeffs, axes=axes)[0]
    return backend.to_device(packed, source_device)
# NOTE(review): fragment of a test body. The enclosing test header, and
# whatever defines `devices` and `mode`, are outside this view.
# Presumably this append runs only when a GPU backend is available
# (its guard is not visible) — confirm.
devices.append(backend.Device(0))
# Run the same convolution checks on every device in `devices` and for
# four real/complex dtypes; results are moved to the CPU before comparing.
for device in devices:
xp = device.xp
with device:
for dtype in [np.float32, np.float64,
np.complex64, np.complex128]:
# Check 1: `util.dirac([1, 3])` (presumably a centered unit impulse —
# defined elsewhere) convolved with a [1, 3] box of ones; the expected
# output has 5 samples.
x = util.dirac([1, 3], device=device, dtype=dtype)
W = xp.ones([1, 3], dtype=dtype)
y = backend.to_device(conv.convolve(
x, W, mode=mode), backend.cpu_device)
npt.assert_allclose(y, [[0, 1, 1, 1, 0]], atol=1e-5)
# Check 2: same impulse with a [1, 2] box of ones; the expected output
# has 4 samples.
x = util.dirac([1, 3], device=device, dtype=dtype)
W = xp.ones([1, 2], dtype=dtype)
y = backend.to_device(conv.convolve(
x, W, mode=mode), backend.cpu_device)
npt.assert_allclose(y, [[0, 1, 1, 0]], atol=1e-5)
# Check 3: a [2, 1, 3] filter bank with output_multi_channel=True; both
# output channels are expected to equal the check-1 result.
x = util.dirac([1, 3], device=device, dtype=dtype)
W = xp.ones([2, 1, 3], dtype=dtype)
y = backend.to_device(
conv.convolve(
x,
W,
mode=mode,
output_multi_channel=True),
backend.cpu_device)
npt.assert_allclose(y, [[[0, 1, 1, 1, 0]],
[[0, 1, 1, 1, 0]]], atol=1e-5)
# NOTE(review): fragment — starts mid-statement. The opening
# `for dtype in [...` line is not visible, and neither are the test header
# and the per-device setup that defines `xp`, `mode` and `device`.
np.complex64, np.complex128]:
# Check 1: adjoint w.r.t. the input of a [1, 1] output of ones through a
# [1, 3] filter of ones; expected input-side result is [[1, 1, 1]].
y = xp.ones([1, 1], dtype=dtype)
W = xp.ones([1, 3], dtype=dtype)
x = backend.to_device(conv.convolve_adjoint_input(
W, y, mode=mode), backend.cpu_device)
npt.assert_allclose(x, [[1, 1, 1]], atol=1e-5)
# Check 2: [1, 2] output of ones through a [1, 2] filter of ones; the
# overlapping middle sample is expected to accumulate to 2.
y = xp.ones([1, 2], dtype=dtype)
W = xp.ones([1, 2], dtype=dtype)
x = backend.to_device(conv.convolve_adjoint_input(
W, y, mode=mode), backend.cpu_device)
npt.assert_allclose(x, [[1, 2, 1]], atol=1e-5)
# Check 3: two output channels with output_multi_channel=True. The
# expected [[2, 2, 2]] is twice check 1, which appears to reflect summing
# the two channels' contributions — confirm against conv.
y = xp.ones([2, 1, 1], dtype=dtype)
W = xp.ones([2, 1, 3], dtype=dtype)
x = backend.to_device(
conv.convolve_adjoint_input(
W,
y,
mode=mode,
output_multi_channel=True),
backend.cpu_device)
npt.assert_allclose(x, [[2, 2, 2]], atol=1e-5)
def _apply(self, input):
    """Return ``input`` multiplied by ``self.mult``.

    The multiplier is conjugated first when ``self.conj`` is set. A scalar
    multiplier equal to 1 short-circuits and returns ``input`` itself (no
    copy). A non-scalar multiplier is moved to the device of ``input``
    before multiplying. All work runs inside that device's context.
    """
    device = backend.get_device(input)
    xp = device.xp
    with device:
        if not np.isscalar(self.mult):
            factor = backend.to_device(self.mult, backend.get_device(input))
            return input * (xp.conj(factor) if self.conj else factor)

        if self.mult == 1:
            return input

        factor = self.mult.conjugate() if self.conj else self.mult
        return input * factor
# NOTE(review): fragment of a larger function. Its header is outside this
# view, as is whatever defines `a`, `i`, `apod`, `os_shape`, `device`,
# `input`, `oversamp`, `n`, `width` and `beta`. Presumably the lines up to
# "Swap back" run inside a loop over the transformed axes `a`, with `i` the
# input size along `a` — confirm.
# Move axis `a` to the last position so the 1-D FFT below (axes=[-1]) runs
# along it; swap the matching `os_shape` entries to stay in step.
output = output.swapaxes(a, -1)
os_shape[a], os_shape[-1] = os_shape[-1], os_shape[a]
# Apodize
output *= apod
# Oversampled FFT
# Resize to `os_shape` (util.resize, defined elsewhere), FFT along the last
# axis with norm=None, then divide by sqrt(i).
output = util.resize(output, os_shape)
output = fft.fft(output, axes=[-1], norm=None)
output /= i**0.5
# Swap back
# Restore the original axis order of `output` and `os_shape`.
output = output.swapaxes(a, -1)
os_shape[a], os_shape[-1] = os_shape[-1], os_shape[a]
# Move `coord` to `device` and map it onto the oversampled grid. Note that
# the grid sizes come from `input.shape`, not `os_shape`.
coord = _scale_coord(backend.to_device(coord, device), input.shape, oversamp)
# Lookup table: `_kb` (defined elsewhere; presumably a Kaiser-Bessel
# kernel) evaluated at arange(n) / n, in coord's dtype, moved to `device`.
table = backend.to_device(
_kb(np.arange(n, dtype=coord.dtype) / n, width, beta, dtype=coord.dtype), device)
# Interpolate `output` at `coord` using the table (interp.interp defined
# elsewhere).
output = interp.interp(output, width, table, coord)
return output