How to use torchdiffeq - 10 common examples

To help you get started, we've selected a few torchdiffeq examples based on how it is commonly used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

From rtqichen/torchdiffeq on GitHub: tests/api_tests.py
            func = lambda y0, t_points: torchdiffeq.odeint(tuple_f, (y0, y0), t_points, method='dopri5')[i]
            self.assertTrue(torch.autograd.gradcheck(func, (y0, t_points)))
From rtqichen/torchdiffeq on GitHub: tests/gradient_tests.py
        func = lambda y0, t_points: torchdiffeq.odeint(f, y0, t_points, method='adaptive_heun')
        self.assertTrue(torch.autograd.gradcheck(func, (y0, t_points)))
From rtqichen/torchdiffeq on GitHub: tests/gradient_tests.py
        func = lambda y0, t_points: torchdiffeq.odeint(f, y0, t_points, method='dopri5')
        self.assertTrue(torch.autograd.gradcheck(func, (y0, t_points)))
From rtqichen/torchdiffeq on GitHub: tests/odeint_tests.py
def test_adams(self):
    """Check the 'adams' solver against every problem registered in ``problems.PROBLEMS``.

    Each problem is built with ``reverse=True`` and solved with
    ``torchdiffeq.odeint(..., method='adams')``. The relative error against the
    reference solution ``sol`` must stay below ``error_tol``. Each problem runs in
    its own ``subTest``, so one failing problem does not hide results for the others.
    """
    for ode in problems.PROBLEMS:
        with self.subTest(ode=ode):
            # Forward the loop key: previously `ode` was never passed, so every
            # iteration rebuilt and re-tested the same problem.
            # NOTE(review): assumes construct_problem accepts an `ode=` keyword
            # selecting the problem by its PROBLEMS key — confirm in tests/problems.py.
            f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, ode=ode, reverse=True)
            y = torchdiffeq.odeint(f, y0, t_points, method='adams')
            self.assertLess(rel_error(sol, y), error_tol)
From rtqichen/torchdiffeq on GitHub: tests/odeint_tests.py
def test_rk4(self):
    """Solve the problem built by ``construct_problem(..., reverse=True)`` with 'rk4'.

    The relative error between the reference ``sol`` and the computed solution
    must be below ``error_tol``.
    """
    f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, reverse=True)
    computed = torchdiffeq.odeint(f, y0, t_points, method='rk4')
    err = rel_error(sol, computed)
    self.assertLess(err, error_tol)
From rtqichen/torchdiffeq on GitHub: tests/odeint_tests.py
def test_euler(self):
    """Solve the problem built by ``construct_problem(..., reverse=True)`` with 'euler'.

    The computed trajectory must match the reference ``sol`` to within
    ``error_tol`` relative error.
    """
    f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, reverse=True)
    computed = torchdiffeq.odeint(f, y0, t_points, method='euler')
    err = rel_error(sol, computed)
    self.assertLess(err, error_tol)
From rtqichen/torchdiffeq on GitHub: tests/odeint_tests.py
def test_explicit_adams(self):
    """Run 'explicit_adams' over only the first time point and compare with ``sol[0]``.

    Only ``t_points[0:1]`` is passed to the solver, and the maximum absolute
    difference from ``sol[0]`` must be below ``error_tol``.
    NOTE(review): presumably ``sol[0]`` is the reference state at ``t_points[0]`` — confirm.
    """
    f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, reverse=True)
    first_time = t_points[0:1]
    computed = torchdiffeq.odeint(f, y0, first_time, method='explicit_adams')
    deviation = max_abs(sol[0] - computed)
    self.assertLess(deviation, error_tol)
From rtqichen/torchdiffeq on GitHub: tests/gradient_tests.py
        func = lambda y0, t_points: torchdiffeq.odeint(f, y0, t_points, method='rk4')
        self.assertTrue(torch.autograd.gradcheck(func, (y0, t_points)))
From rtqichen/torchdiffeq on GitHub: tests/gradient_tests.py
        func = lambda y0, t_points: torchdiffeq.odeint(f, y0, t_points, method='midpoint')
        self.assertTrue(torch.autograd.gradcheck(func, (y0, t_points)))
From rtqichen/torchdiffeq on GitHub: tests/odeint_tests.py
def test_adjoint(self):
    """Check ``odeint_adjoint`` (dopri5) against every problem in ``problems.PROBLEMS``.

    Each problem is built with ``reverse=True`` and solved with
    ``torchdiffeq.odeint_adjoint(..., method='dopri5')``. The relative error
    against the reference ``sol`` must stay below ``error_tol``. Each problem
    runs in its own ``subTest`` so failures are reported per problem.
    """
    for ode in problems.PROBLEMS:
        with self.subTest(ode=ode):
            # Forward the loop key: previously `ode` was never passed, so every
            # iteration rebuilt and re-tested the same problem.
            # NOTE(review): assumes construct_problem accepts an `ode=` keyword
            # selecting the problem by its PROBLEMS key — confirm in tests/problems.py.
            f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, ode=ode, reverse=True)
            y = torchdiffeq.odeint_adjoint(f, y0, t_points, method='dopri5')
            self.assertLess(rel_error(sol, y), error_tol)