How to use the torchdiffeq.odeint function in torchdiffeq

To help you get started, we’ve selected a few torchdiffeq examples based on popular ways the library is used in public projects.
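Before the project snippets, here is a minimal self-contained sketch of a direct odeint call. The dynamics function, time grid and tolerances are illustrative assumptions rather than code from any of the projects below.

import torch
import torchdiffeq

# Illustrative dynamics: dy/dt = -y (exponential decay); any callable func(t, y) -> dy/dt works.
def f(t, y):
    return -y

y0 = torch.tensor([1.0])                  # state at the initial time t[0]
t = torch.linspace(0.0, 2.0, 20)          # times at which the solution is reported
y = torchdiffeq.odeint(f, y0, t, method='dopri5', rtol=1e-6, atol=1e-8)
# y has shape (len(t),) + y0.shape; y[0] equals y0 and later rows approximate y(t).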


github rtqichen / torchdiffeq / tests / api_tests.py
# gradcheck through odeint with a tuple-valued state; tuple_f, y0, t_points and i come from the enclosing test
func = lambda y0, t_points: torchdiffeq.odeint(tuple_f, (y0, y0), t_points, method='dopri5')[i]
self.assertTrue(torch.autograd.gradcheck(func, (y0, t_points)))
github rtqichen / torchdiffeq / tests / gradient_tests.py
# f, y0 and t_points are built earlier in the test; this checks gradients through the adaptive_heun solver
func = lambda y0, t_points: torchdiffeq.odeint(f, y0, t_points, method='adaptive_heun')
self.assertTrue(torch.autograd.gradcheck(func, (y0, t_points)))
github rtqichen / torchdiffeq / tests / gradient_tests.py
func = lambda y0, t_points: torchdiffeq.odeint(f, y0, t_points, method='dopri5')
self.assertTrue(torch.autograd.gradcheck(func, (y0, t_points)))
github rtqichen / torchdiffeq / tests / odeint_tests.py
def test_adams(self):
    for ode in problems.PROBLEMS.keys():
        f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, reverse=True)

        y = torchdiffeq.odeint(f, y0, t_points, method='adams')
        with self.subTest(ode=ode):
            self.assertLess(rel_error(sol, y), error_tol)
github rtqichen / torchdiffeq / tests / odeint_tests.py
def test_rk4(self):
    f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, reverse=True)

    y = torchdiffeq.odeint(f, y0, t_points, method='rk4')
    self.assertLess(rel_error(sol, y), error_tol)
github rtqichen / torchdiffeq / tests / odeint_tests.py
def test_euler(self):
    f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, reverse=True)

    y = torchdiffeq.odeint(f, y0, t_points, method='euler')
    self.assertLess(rel_error(sol, y), error_tol)
github rtqichen / torchdiffeq / tests / odeint_tests.py
def test_explicit_adams(self):
    f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, reverse=True)

    y = torchdiffeq.odeint(f, y0, t_points[0:1], method='explicit_adams')
    self.assertLess(max_abs(sol[0] - y), error_tol)
github rtqichen / torchdiffeq / tests / gradient_tests.py
func = lambda y0, t_points: torchdiffeq.odeint(f, y0, t_points, method='rk4')
self.assertTrue(torch.autograd.gradcheck(func, (y0, t_points)))
github rtqichen / torchdiffeq / tests / gradient_tests.py
func = lambda y0, t_points: torchdiffeq.odeint(f, y0, t_points, method='midpoint')
self.assertTrue(torch.autograd.gradcheck(func, (y0, t_points)))
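The gradient-check snippets above rely on f, y0 and t_points produced elsewhere in the test module. A rough self-contained sketch of the same pattern, using an illustrative dynamics function (not the repository's problems helper) and double precision, since gradcheck expects it:

import torch
import torchdiffeq

# Illustrative linear dynamics; stands in for the f built by the test suite.
def f(t, y):
    return -0.5 * y

y0 = torch.randn(3, dtype=torch.float64, requires_grad=True)
t_points = torch.linspace(0.0, 1.0, 5, dtype=torch.float64, requires_grad=True)

func = lambda y0, t_points: torchdiffeq.odeint(f, y0, t_points, method='dopri5')
print(torch.autograd.gradcheck(func, (y0, t_points)))  # True when analytic gradients match finite differences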
github cagatayyildiz / ODE2VAE / ode2vae_mnist_minimal.py
h = self.encoder(X[:,0])
qz0_m, qz0_logv = self.fc1(h), self.fc2(h) # N,2q & N,2q
q = qz0_m.shape[1]//2
# latent samples
eps   = torch.randn_like(qz0_m)  # N,2q
z0    = qz0_m + eps*torch.exp(qz0_logv) # N,2q
logp0 = self.mvn.log_prob(eps) # N
# ODE
t  = dt * torch.arange(T,dtype=torch.float).to(z0.device)
ztL   = []
logpL = []
# sample L trajectories
for l in range(L):
    f       = self.bnn.draw_f() # draw a differential function
    oderhs  = lambda t,vs: self.ode2vae_rhs(t,vs,f) # make the ODE forward function
    zt,logp = odeint(oderhs,(z0,logp0),t,method=method) # T,N,2q & T,N
    ztL.append(zt.permute([1,0,2]).unsqueeze(0)) # 1,N,T,2q
    logpL.append(logp.permute([1,0]).unsqueeze(0)) # 1,N,T
ztL   = torch.cat(ztL,0) # L,N,T,2q
logpL = torch.cat(logpL) # L,N,T
# decode
st_muL = ztL[:,:,:,q:] # L,N,T,q
s = self.fc3(st_muL.contiguous().view([L*N*T,q]) ) # L*N*T,h_dim
Xrec = self.decoder(s) # L*N*T,nc,d,d
Xrec = Xrec.view([L,N,T,nc,d,d]) # L,N,T,nc,d,d
# likelihood and elbo
if inst_enc:
    h = self.encoder(X.contiguous().view([N*T,nc,d,d]))
    qz_enc_m, qz_enc_logv = self.fc1(h), self.fc2(h) # N*T,2q & N*T,2q
    lhood, kl_z, kl_w, inst_KL = self.elbo(qz0_m, qz0_logv, ztL, logpL, X, Xrec, L, qz_enc_m, qz_enc_logv)
    elbo = Ndata*(lhood-kl_z-inst_KL) - kl_w
else: