Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
func = lambda y0, t_points: torchdiffeq.odeint(tuple_f, (y0, y0), t_points, method='dopri5')[i]
self.assertTrue(torch.autograd.gradcheck(func, (y0, t_points)))
func = lambda y0, t_points: torchdiffeq.odeint(f, y0, t_points, method='adaptive_heun')
self.assertTrue(torch.autograd.gradcheck(func, (y0, t_points)))
func = lambda y0, t_points: torchdiffeq.odeint(f, y0, t_points, method='dopri5')
self.assertTrue(torch.autograd.gradcheck(func, (y0, t_points)))
def test_adams(self):
    """Check the 'adams' solver against the reference solution for every problem.

    Each key of ``problems.PROBLEMS`` is run as its own subtest on a problem
    built with ``reverse=True``. The relative error must stay below
    ``error_tol``.
    """
    for ode in problems.PROBLEMS.keys():
        # Build and solve inside the subtest so a solver exception is reported
        # against this particular ode and the remaining problems still run.
        with self.subTest(ode=ode):
            # Forward the loop variable. Without it, every iteration builds the
            # same problem and the loop never covers the other entries.
            # (Assumes construct_problem picks the problem via its ``ode``
            # keyword; confirm against the problems module.)
            f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, ode=ode, reverse=True)
            y = torchdiffeq.odeint(f, y0, t_points, method='adams')
            self.assertLess(rel_error(sol, y), error_tol)
def test_rk4(self):
    """The 'rk4' solver must match the reference solution within ``error_tol``.

    The problem is built with ``reverse=True``.
    """
    problem = problems.construct_problem(TEST_DEVICE, reverse=True)
    dynamics, initial_state, times, expected = problem
    solution = torchdiffeq.odeint(dynamics, initial_state, times, method='rk4')
    self.assertLess(rel_error(expected, solution), error_tol)
def test_euler(self):
    """The 'euler' solver must match the reference solution within ``error_tol``.

    The problem is built with ``reverse=True``.
    """
    dynamics, initial_state, times, expected = problems.construct_problem(
        TEST_DEVICE, reverse=True)
    solution = torchdiffeq.odeint(
        dynamics, initial_state, times, method='euler')
    relative = rel_error(expected, solution)
    self.assertLess(relative, error_tol)
def test_explicit_adams(self):
    """Check 'explicit_adams' on a one-point time grid, then gradcheck rk4 and midpoint.

    NOTE(review): the two gradcheck assertions look like they were spliced in
    from a separate gradient test; confirm they belong in this method.
    """
    dynamics, initial_state, times, expected = problems.construct_problem(
        TEST_DEVICE, reverse=True)

    # Only the first time point is requested, so the output is compared
    # against the first entry of the reference solution.
    first_time_only = times[0:1]
    solution = torchdiffeq.odeint(
        dynamics, initial_state, first_time_only, method='explicit_adams')
    self.assertLess(max_abs(expected[0] - solution), error_tol)

    def integrate_rk4(y0, t_points):
        return torchdiffeq.odeint(dynamics, y0, t_points, method='rk4')

    self.assertTrue(torch.autograd.gradcheck(integrate_rk4, (initial_state, times)))

    def integrate_midpoint(y0, t_points):
        return torchdiffeq.odeint(dynamics, y0, t_points, method='midpoint')

    self.assertTrue(torch.autograd.gradcheck(integrate_midpoint, (initial_state, times)))
def test_adjoint(self):
    """Check ``odeint_adjoint`` (dopri5) against the reference solution for every problem.

    Each key of ``problems.PROBLEMS`` is run as its own subtest on a problem
    built with ``reverse=True``. The relative error must stay below
    ``error_tol``.
    """
    for ode in problems.PROBLEMS.keys():
        # Build and solve inside the subtest so a solver exception is reported
        # against this particular ode and the remaining problems still run.
        with self.subTest(ode=ode):
            # Forward the loop variable. Without it, every iteration builds the
            # same problem and the loop never covers the other entries.
            # (Assumes construct_problem picks the problem via its ``ode``
            # keyword; confirm against the problems module.)
            f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, ode=ode, reverse=True)
            y = torchdiffeq.odeint_adjoint(f, y0, t_points, method='dopri5')
            self.assertLess(rel_error(sol, y), error_tol)