Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_Hstack(self):
    """Hstack of two identities should sum the two stacked input pieces.

    Two cases are exercised:

    * 1-D pieces of shape [5], joined with ``util.vec`` and using the
      default Hstack axis.
    * 2-D pieces of shape [5, 3], concatenated along axis 1 and using
      ``Hstack(..., axis=1)``.

    Each resulting operator is also run through the linearity, adjoint and
    pickling checks.
    """
    cases = [
        # (piece shape, how the input is assembled, extra Hstack kwargs)
        ([5], util.vec, {}),
        ([5, 3], lambda pieces: np.concatenate(pieces, axis=1), {'axis': 1}),
    ]
    for piece_shape, assemble, hstack_kwargs in cases:
        eye = linop.Identity(piece_shape)
        left = util.randn(piece_shape)
        right = util.randn(piece_shape)
        stacked = assemble([left, right])
        op = linop.Hstack([eye, eye], **hstack_kwargs)
        npt.assert_allclose(op(stacked), left + right)
        for check in (self.check_linop_linear,
                      self.check_linop_adjoint,
                      self.check_linop_pickleable):
            check(op)
def test_Add(self):
    """Add of an identity with itself should double its input.

    The operator is also run through the linearity, adjoint and pickling
    checks.
    """
    dims = [5]
    eye = linop.Identity(dims)
    doubled = linop.Add([eye, eye])
    sample = util.randn(dims)
    npt.assert_allclose(doubled(sample), 2 * sample)
    for check in (self.check_linop_linear,
                  self.check_linop_adjoint,
                  self.check_linop_pickleable):
        check(doubled)
def test_Diag(self):
    """Diag of two identities should reproduce its input unchanged.

    Two cases are exercised:

    * 1-D identities of shape [5] without explicit axes, so the input has
      2 * 5 = 10 entries.
    * 2-D identities of shape [5, 3] with ``iaxis=1, oaxis=1``, so two
      [5, 3] blocks form a [5, 6] input.

    Each operator goes through the linearity, adjoint and pickling checks,
    matching the other operator tests in this suite.
    """
    shape = [5]
    I = linop.Identity(shape)
    x = util.randn([10])
    A = linop.Diag([I, I])
    npt.assert_allclose(A(x), x)
    self.check_linop_linear(A)
    self.check_linop_adjoint(A)
    self.check_linop_pickleable(A)

    shape = [5, 3]
    I = linop.Identity(shape)
    x = util.randn([5, 6])
    A = linop.Diag([I, I], iaxis=1, oaxis=1)
    npt.assert_allclose(A(x), x)
    self.check_linop_linear(A)
    self.check_linop_adjoint(A)
    # Previously missing: every other case in these tests also verifies
    # that the operator survives pickling.
    self.check_linop_pickleable(A)
def test_Vstack(self):
    """Vstack of two identities should emit two stacked copies of its input.

    Two cases are checked: a 1-D input of shape [5] compared against
    ``util.vec``, and a 2-D input of shape [5, 3] with ``axis=1`` compared
    against ``np.concatenate(..., axis=1)``.
    """
    def run_case(shape, expected_stack, **vstack_kwargs):
        # Build Vstack([I, I]), compare it against the reference stacking,
        # then run the shared linop sanity checks.
        eye = linop.Identity(shape)
        sample = util.randn(shape)
        op = linop.Vstack([eye, eye], **vstack_kwargs)
        npt.assert_allclose(op(sample), expected_stack([sample, sample]))
        self.check_linop_linear(op)
        self.check_linop_adjoint(op)
        self.check_linop_pickleable(op)

    run_case([5], util.vec)
    run_case([5, 3], lambda parts: np.concatenate(parts, axis=1), axis=1)
# NOTE(review): this is a second `def test_Hstack` in this listing. If both
# definitions live in the same class, this later one replaces the earlier one
# and only one of them runs. The visible body also stops after building `x` for
# the 2-D case (no Hstack is built or checked there) - TODO confirm against the
# full file and remove the duplicate.
def test_Hstack(self):
    """Hstack of two identities should sum the stacked input pieces (partial view)."""
    # 1-D case: pieces joined with util.vec, default Hstack axis.
    shape = [5]
    I = linop.Identity(shape)
    x1 = util.randn(shape)
    x2 = util.randn(shape)
    x = util.vec([x1, x2])
    A = linop.Hstack([I, I])
    npt.assert_allclose(A(x), x1 + x2)
    self.check_linop_linear(A)
    self.check_linop_adjoint(A)
    self.check_linop_pickleable(A)
    # 2-D case: pieces concatenated along axis 1.
    shape = [5, 3]
    I = linop.Identity(shape)
    x1 = util.randn(shape)
    x2 = util.randn(shape)
    x = np.concatenate([x1, x2], axis=1)
r -= self.y
with self.x_device:
gradf_x = self.A.H(r)
if self.lamda != 0:
if self.R is None:
util.axpy(gradf_x, self.lamda, x)
else:
util.axpy(gradf_x, self.lamda, self.R.H(self.R(x)))
if self.mu != 0:
util.axpy(gradf_x, self.mu, x - self.z)
return gradf_x
I = linop.Identity(self.x.shape)
AHA = self.A.H * self.A
if self.lamda != 0:
if self.R is None:
AHA += self.lamda * I
else:
AHA += self.lamda * self.R.H * self.R
if self.mu != 0:
AHA += self.mu * I
max_eig = MaxEig(AHA, dtype=self.x.dtype,
device=self.x_device, max_iter=self.max_power_iter,
show_pbar=self.show_pbar).run()
if max_eig == 0:
def _get_ConjugateGradient(self):
    """Set ``self.alg`` to a ConjugateGradient solver for the normal equations.

    Operator (left-hand side):

    * ``A^H A``,
    * plus ``lamda * I`` when ``lamda != 0`` and ``R`` is None, or
      ``lamda * R^H R`` when ``lamda != 0`` and ``R`` is set,
    * plus ``mu * I`` when ``mu != 0``.

    Right-hand side: ``A^H y``, plus ``mu * z`` when ``mu != 0``.

    The solver starts from ``self.x`` and uses ``self.P`` and
    ``self.max_iter``. Presumably this minimizes
    ``1/2||Ax - y||^2 + lamda/2||Rx||^2 + mu/2||x - z||^2``; confirm against
    the class docstring.
    """
    eye = linop.Identity(self.x.shape)
    normal_op = self.A.H * self.A
    rhs = self.A.H(self.y)

    if self.lamda != 0:
        penalty = (self.lamda * eye if self.R is None
                   else self.lamda * self.R.H * self.R)
        normal_op += penalty

    if self.mu != 0:
        normal_op += self.mu * eye
        # util.axpy presumably updates rhs in place as rhs += mu * z.
        util.axpy(rhs, self.mu, self.z)

    self.alg = ConjugateGradient(
        normal_op, rhs, self.x, P=self.P, max_iter=self.max_iter)
def _update(self):
    """Run one outer iteration: impose the measured magnitudes, then refit ``self.x``.

    Steps visible in the code:

    1. Take the phase of ``A * x`` and attach it to ``self.y`` to form
       ``y_hat``.
    2. Run a ConjugateGradient solve, capped at 5 iterations and started from
       ``self.x``, on ``(A^H A + lamb * I) x = A^H y_hat``.
    3. Store ``sum(| |A x| - y |)`` in ``self.residual`` and increment
       ``self.iter``.
    """
    device = backend.get_device(self.y)
    xp = device.xp
    with device:
        # Keep the magnitudes from self.y but take the phase of the current
        # forward prediction (`self.A * self.x` presumably applies A to x).
        y_hat = self.y * xp.exp(1j * xp.angle(self.A * self.x))
        # Normal equations with an added lamb * I term.
        I = sp.linop.Identity(self.A.ishape)
        system = self.A.H * self.A + self.lamb * I
        b = self.A.H * y_hat
        # Inner solve starts from the current estimate and is capped at 5 iterations.
        alg_internal = ConjugateGradient(system, b, self.x, max_iter=5)
        while not alg_internal.done():
            alg_internal.update()
        # NOTE(review): indentation was lost in this listing. This assignment is
        # read here as running after the loop - confirm it is not part of the
        # loop body.
        self.x = alg_internal.x
        # Mismatch between predicted magnitudes |A x| and the measured y.
        self.residual = xp.sum(xp.absolute(xp.absolute(self.A * self.x)
                                           - self.y))
        self.iter += 1
def FiniteDifference(ishape, axes=None):
    """Linear operator that computes finite difference gradient.

    For each selected axis, the operator computes ``x`` minus ``x`` circularly
    shifted by one along that axis. That result is reshaped to
    ``[1] + ishape``, and the per-axis pieces are stacked along axis 0.

    Args:
        ishape (tuple of ints): Input shape.
        axes (tuple or list): Axes to circularly shift. All axes are used if
            None.

    Returns:
        The Vstack (along axis 0) of the per-axis difference operators.
    """
    identity = Identity(ishape)
    axes = util._normalize_axes(axes, len(ishape))

    def _difference_along(axis):
        # I - shift: the input minus its circular shift by one along `axis`.
        diff = identity - Circshift(ishape, [1], axes=[axis])
        # Add a leading singleton axis so the pieces can be stacked on axis 0.
        lift = Reshape([1] + list(ishape), ishape)
        return lift * diff

    return Vstack([_difference_along(axis) for axis in axes], axis=0)