# Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.
def test_matmul_vec(self):
    """NonLazyVariable.matmul with a vector RHS matches plain tensor matmul, forward and backward."""
    # Forward: both paths should produce the same matrix-vector product.
    res = NonLazyVariable(self.mat_var).matmul(self.vec_var)
    actual = self.mat_var_clone.matmul(self.vec_var_clone)
    self.assertTrue(approx_equal(res, actual))
    # Backward: use a deterministic random gradient instead of torch.Tensor(3),
    # which returns *uninitialized* memory and can contain NaN/inf, making the
    # gradient comparisons below flaky.
    grad_output = torch.randn(3)
    res.backward(gradient=grad_output)
    actual.backward(gradient=grad_output)
    self.assertTrue(approx_equal(self.mat_var_clone.grad.data, self.mat_var.grad.data))
    self.assertTrue(approx_equal(self.vec_var_clone.grad.data, self.vec_var.grad.data))
def test_log_det_only(self):
    """log_det of a batch of matrices matches precomputed values; gradients match the inverse formula."""
    # Forward pass -- stochastic trace estimation needs a large sample count.
    with gpytorch.settings.num_trace_samples(1000):
        res = NonLazyVariable(self.mats_var).log_det()
    for idx in range(self.mats_var.size(0)):
        self.assert_scalar_almost_equal(res.data[idx], self.log_dets[idx], places=1)
    # Backward: d(log det A)/dA = A^{-1}, scaled per-batch by the incoming gradient.
    grad_output = torch.Tensor([3, 4])
    expected_grads = [
        self.mats_var_clone[k].data.inverse().mul(grad_output[k]).unsqueeze(0)
        for k in range(2)
    ]
    actual_mat_grad = torch.cat(expected_grads)
    res.backward(gradient=grad_output)
    self.assertTrue(approx_equal(actual_mat_grad, self.mats_var.grad.data, epsilon=1e-1))
def test_left_t_interp_on_a_vector(self):
    """left_t_interp with a vector RHS equals multiplying by the transposed interpolation matrix."""
    rhs = torch.randn(9)
    result = left_t_interp(self.interp_indices, self.interp_values, rhs, 6).data
    expected = torch.matmul(self.interp_matrix.transpose(-1, -2), rhs)
    self.assertTrue(approx_equal(result, expected))
def test_matmul_multiple_vecs(self):
    """NonLazyVariable.matmul with a matrix RHS matches plain tensor matmul, forward and backward."""
    # Forward: both paths should produce the same matrix-matrix product.
    res = NonLazyVariable(self.mat_var).matmul(self.vecs_var)
    actual = self.mat_var_clone.matmul(self.vecs_var_clone)
    self.assertTrue(approx_equal(res, actual))
    # Backward: use a deterministic random gradient instead of torch.Tensor(3, 4),
    # which returns *uninitialized* memory and can contain NaN/inf, making the
    # gradient comparisons below flaky.
    grad_output = torch.randn(3, 4)
    res.backward(gradient=grad_output)
    actual.backward(gradient=grad_output)
    self.assertTrue(approx_equal(self.mat_var_clone.grad.data, self.mat_var.grad.data))
    self.assertTrue(approx_equal(self.vecs_var_clone.grad.data, self.vecs_var.grad.data))
def test_inv_quad_only_vector(self):
    """inv_quad(v) should equal v^T A^{-1} v, with matching gradients on both graphs."""
    # Forward pass: compare against an explicit inverse-solve.
    res = NonLazyVariable(self.mat_var).inv_quad(self.vec_var)
    solve = self.mat_var_clone.inverse().matmul(self.vec_var_clone)
    actual = solve.mul(self.vec_var_clone).sum()
    self.assert_scalar_almost_equal(res, actual, places=1)
    # Backward: apply the same random scalar gradient to both graphs.
    inv_quad_grad_output = torch.randn(1)
    actual.backward(gradient=inv_quad_grad_output)
    res.backward(gradient=inv_quad_grad_output)
    self.assertTrue(approx_equal(self.mat_var_clone.grad.data, self.mat_var.grad.data, epsilon=1e-1))
    self.assertTrue(approx_equal(self.vec_var_clone.grad.data, self.vec_var.grad.data))
def test_rotate_matrix_forward(self):
    """circulant.rotate(a, i) should equal left-multiplying by the i-th power of the cyclic-shift matrix."""
    mat = torch.randn(5, 5)
    # Elementary cyclic-shift permutation matrix.
    shift = torch.zeros(5, 5)
    shift[0, 4] = 1
    shift[1:, :-1] = torch.eye(4)
    power = shift.clone()  # running value of shift^i
    for i in range(1, 5):
        rotated = circulant.rotate(mat, i)
        expected = power.matmul(mat)
        self.assertTrue(utils.approx_equal(expected, rotated))
        power = power.matmul(shift)
def test_root_decomposition(self):
    """root_decomposition returns R with R R^T equal to the matrix; trace gradients match."""
    # Forward: reconstruct the matrix from its root factor.
    root = NonLazyVariable(self.mat_var).root_decomposition()
    reconstruction = root.matmul(root.transpose(-1, -2))
    self.assertTrue(approx_equal(reconstruction.data, self.mat_var.data))
    # Backward: trace() is a scalar, so no explicit gradient argument is needed.
    reconstruction.trace().backward()
    self.mat_var_clone.trace().backward()
    self.assertTrue(approx_equal(self.mat_var.grad.data, self.mat_var_clone.grad.data))
# NOTE(review): this span appears to be the body of a *separate* test
# (likely test_inv_quad_log_det) whose `def` header is missing from this
# chunk -- `actual_inv_quad` and `self.log_det` are referenced but never
# defined above. Confirm against the full file before modifying.
with gpytorch.settings.num_trace_samples(1000):
nlv = NonLazyVariable(self.mat_var)
# Joint inv_quad + log_det computation; compared against the (externally
# computed) exact values with loose tolerance due to stochastic estimation.
res_inv_quad, res_log_det = nlv.inv_quad_log_det(inv_quad_rhs=self.vecs_var, log_det=True)
self.assert_scalar_almost_equal(res_inv_quad, actual_inv_quad, places=1)
self.assert_scalar_almost_equal(res_log_det, self.log_det, places=1)
# Backward
inv_quad_grad_output = torch.Tensor([3])
log_det_grad_output = torch.Tensor([4])
actual_inv_quad.backward(gradient=inv_quad_grad_output)
# Expected log_det contribution: d(log det A)/dA = A^{-1}, scaled by the gradient.
self.mat_var_clone.grad.data.add_(self.mat_var_clone.data.inverse() * log_det_grad_output)
# retain_graph=True so the shared graph survives for the second backward call.
res_inv_quad.backward(gradient=inv_quad_grad_output, retain_graph=True)
res_log_det.backward(gradient=log_det_grad_output)
self.assertTrue(approx_equal(self.mat_var_clone.grad.data, self.mat_var.grad.data, epsilon=1e-1))
self.assertTrue(approx_equal(self.vecs_var_clone.grad.data, self.vecs_var.grad.data))
def test_circulant_inv_matmul(self):
    """circulant_inv_matmul(a, M) should equal C(a)^{-1} @ M for the full circulant matrix C(a)."""
    coeffs = torch.randn(5)
    rhs = torch.randn(5, 5)
    result = circulant.circulant_inv_matmul(coeffs, rhs)
    dense_circulant = circulant.circulant(coeffs)
    expected = dense_circulant.inverse().mm(rhs)
    self.assertTrue(utils.approx_equal(result, expected))
def test_batch_diag(self):
    """BlockDiagonalLazyTensor.diag() matches the diagonal of the hand-built dense block-diagonal matrix."""
    # NOTE(review): `blocks` is a free name here -- presumably a module-level
    # fixture of shape (8, 4, 4); confirm against the full file.
    block_var = torch.tensor(blocks.data, requires_grad=True)
    # Assemble the two dense 16x16 block-diagonal matrices by hand.
    dense = torch.zeros(2, 16, 16)
    for batch in range(2):
        for blk in range(4):
            start, stop = blk * 4, (blk + 1) * 4
            dense[batch, start:stop, start:stop] = block_var[batch * 4 + blk]
    res = BlockDiagonalLazyTensor(NonLazyTensor(block_var), n_blocks=4).diag()
    expected = torch.stack([dense[0].diag(), dense[1].diag()])
    self.assertTrue(approx_equal(expected.data, res.data))