How to use the torchdiffeq.odeint_adjoint function in torchdiffeq

To help you get started, we’ve selected a few torchdiffeq examples based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

GitHub: rtqichen / torchdiffeq / tests / odeint_tests.py (view on GitHub)
def test_adjoint(self):
    """Check ``torchdiffeq.odeint_adjoint`` (dopri5) on every reference problem.

    For each key of ``problems.PROBLEMS`` the problem is built on
    ``TEST_DEVICE`` with ``reverse=True`` and integrated with
    ``odeint_adjoint``. The relative error against the reference solution
    must stay below ``error_tol``. Each problem runs inside its own
    ``subTest``, so a failure (including an exception while building or
    solving) is reported against the problem that caused it.
    """
    for ode in problems.PROBLEMS:
        with self.subTest(ode=ode):
            # Pass `ode` through. Without it, construct_problem presumably
            # builds its default problem on every iteration, and all
            # subTests would check the same ODE.
            f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, ode=ode, reverse=True)
            y = torchdiffeq.odeint_adjoint(f, y0, t_points, method='dopri5')
            self.assertLess(rel_error(sol, y), error_tol)
GitHub: rtqichen / torchdiffeq / tests / gradient_tests.py (view on GitHub)
        # Bind `f` (from the enclosing scope, not visible here) and the dopri5
        # solver choice, so that `func` depends only on (y0, t_points).
        func = lambda y0, t_points: torchdiffeq.odeint_adjoint(f, y0, t_points, method='dopri5')
        ys = func(y0, t_points)
GitHub: rtqichen / torchdiffeq / tests / odeint_tests.py (view on GitHub)
def test_adjoint(self):
    """Check ``torchdiffeq.odeint_adjoint`` (dopri5) on every reference problem.

    For each key of ``problems.PROBLEMS`` the problem is built on
    ``TEST_DEVICE`` with ``reverse=True`` and integrated with
    ``odeint_adjoint``. The relative error against the reference solution
    must stay below ``error_tol``. Each problem runs inside its own
    ``subTest``, so a failure (including an exception while building or
    solving) is reported against the problem that caused it.
    """
    for ode in problems.PROBLEMS:
        with self.subTest(ode=ode):
            # Pass `ode` through. Without it, construct_problem presumably
            # builds its default problem on every iteration, and all
            # subTests would check the same ODE.
            f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, ode=ode, reverse=True)
            y = torchdiffeq.odeint_adjoint(f, y0, t_points, method='dopri5')
            self.assertLess(rel_error(sol, y), error_tol)
GitHub: rtqichen / torchdiffeq / tests / gradient_tests.py (view on GitHub)
def test_adams_adjoint_against_dopri5(self):
    """Compare adjoint-method gradients (adams) with direct backprop (dopri5).

    A first problem instance is solved with ``odeint_adjoint`` using adams,
    and a random cotangent (scaled by 0.1) is backpropagated through it.
    Parameters of ``unused_module`` must get exactly zero gradient. A fresh
    problem instance is then solved with plain ``odeint`` using dopri5, and
    the same cotangent is backpropagated. The gradients w.r.t. ``y0``, the
    time points and ``A`` must agree within per-quantity tolerances.
    """
    # Adjoint pass.
    adj_func, adj_y0, adj_t = self.problem()
    adj_solution = torchdiffeq.odeint_adjoint(adj_func, adj_y0, adj_t, method='adams')
    cotangent = torch.rand_like(adj_solution) * 0.1
    adj_solution.backward(cotangent)

    adj_y0_grad, adj_t_grad, adj_A_grad = adj_y0.grad, adj_t.grad, adj_func.A.grad
    # A module that does not feed into the dynamics must receive zero gradient.
    self.assertEqual(max_abs(adj_func.unused_module.weight.grad), 0)
    self.assertEqual(max_abs(adj_func.unused_module.bias.grad), 0)

    # Reference pass: ordinary backprop through the dopri5 solver.
    ref_func, ref_y0, ref_t = self.problem()
    torchdiffeq.odeint(ref_func, ref_y0, ref_t, method='dopri5').backward(cotangent)

    comparisons = (
        (ref_y0.grad, adj_y0_grad, 5e-2),
        (ref_t.grad, adj_t_grad, 5e-4),
        (ref_func.A.grad, adj_A_grad, 2e-2),
    )
    for reference, adjoint, tolerance in comparisons:
        self.assertLess(max_abs(reference - adjoint), tolerance)
GitHub: rtqichen / torchdiffeq / examples / ode_demo.py (view on GitHub)
if __name__ == '__main__':
    # Script entry point: train ODEFunc with RMSprop so that its integrated
    # trajectories match windows of the ground-truth trajectory, evaluating and
    # visualizing on the full time grid every args.test_freq iterations.

    # Counter passed to visualize(); advanced once per evaluation.
    ii = 0

    func = ODEFunc()
    optimizer = optim.RMSprop(func.parameters(), lr=1e-3)
    # Timestamp at the end of the previous iteration, for per-iteration timing.
    end = time.time()

    # NOTE(review): 0.97 looks like a moving-average momentum; confirm against
    # RunningAverageMeter. Neither meter is read within this block.
    time_meter = RunningAverageMeter(0.97)
    loss_meter = RunningAverageMeter(0.97)

    for itr in range(1, args.niters + 1):
        # One training step: integrate the learned dynamics from a random batch
        # of start states over batch_t, then minimize the mean absolute error
        # against the true states.
        optimizer.zero_grad()
        batch_y0, batch_t, batch_y = get_batch()
        pred_y = odeint(func, batch_y0, batch_t)
        loss = torch.mean(torch.abs(pred_y - batch_y))
        loss.backward()
        optimizer.step()

        # Covers the training step only: `end` is reset after the evaluation
        # below, so evaluation/visualization time is excluded.
        time_meter.update(time.time() - end)
        loss_meter.update(loss.item())

        if itr % args.test_freq == 0:
            # Periodic evaluation over the whole time grid, without autograd.
            with torch.no_grad():
                pred_y = odeint(func, true_y0, t)
                loss = torch.mean(torch.abs(pred_y - true_y))
                print('Iter {:04d} | Total Loss {:.6f}'.format(itr, loss.item()))
                visualize(true_y, pred_y, func, ii)
                ii += 1

        end = time.time()
GitHub: DIAGNijmegen / neural-odes-segmentation / model_utils.py (view on GitHub)
def forward(self, x, eval_times=None):
    """Solve the ODE defined by ``self.odefunc`` from the initial state ``x``.

    Uses dopri5 with ``rtol = atol = self.tol`` and a ``max_num_steps`` cap of
    ``MAX_NUM_STEPS``. When ``self.adjoint`` is truthy, the solve goes through
    ``odeint_adjoint``; otherwise through ``odeint``.

    Args:
        x: initial state; also sets the dtype/device of the integration times.
        eval_times: optional tensor of times at which to evaluate. When omitted,
            the ODE is integrated over [0, 1].

    Returns:
        The solver output at time 1 when ``eval_times`` is None; otherwise
        the full solver output for ``eval_times``.
    """
    # The forward pass *is* the ODE solve, so restart the evaluation counter.
    self.odefunc.nfe = 0

    final_only = eval_times is None
    if final_only:
        integration_time = torch.tensor([0, 1]).float().type_as(x)
    else:
        integration_time = eval_times.type_as(x)

    # Both solvers take identical arguments; only the function differs.
    solver = odeint_adjoint if self.adjoint else odeint
    out = solver(self.odefunc, x, integration_time,
                 rtol=self.tol, atol=self.tol, method='dopri5',
                 options={'max_num_steps': MAX_NUM_STEPS})

    # Index 1 is the second (final) time point of the default [0, 1] grid.
    return out[1] if final_only else out
GitHub: rtqichen / ffjord / vae_lib / models / CNFVAE.py (view on GitHub)
def forward(self, x):
    """Encode ``x``, push the sampled latent through every CNF step, decode.

    Returns:
        ``(x_mean, z_mu, z_var, neg_delta_logp, z0, z)``, where
        ``neg_delta_logp`` is the negated ``delta_logp`` accumulated over all
        flow steps, flattened to 1-D; ``z0`` is the latent sampled before the
        flows; and ``z`` is the latent after the last flow.
    """
    # Reset here; this method does not update it further.
    self.log_det_j = 0.

    z_mu, z_var, am_params = self.encode(x)

    # Reparameterized sample of the initial latent.
    z0 = self.reparameterize(z_mu, z_var)

    # One log-density change per row of x, carried through all flow steps.
    delta_logp = torch.zeros(x.shape[0], 1).to(x)
    z = z0
    for flow, flow_params in zip(self.odefuncs, am_params):
        # Per-step parameters (presumably amortized by the encoder) are
        # appended to the ODE state after (z, delta_logp).
        unpacked = flow.diffeq._unpack_params(flow_params)
        flow.before_odeint()
        trajectory = odeint(
            flow,
            (z, delta_logp) + tuple(unpacked),
            self.integration_times.to(z),
            atol=self.atol,
            rtol=self.rtol,
            method=self.solver,
        )
        # Keep only the state at the last integration time.
        z_path, logp_path = trajectory[0], trajectory[1]
        z, delta_logp = z_path[-1], logp_path[-1]

    x_mean = self.decode(z)

    return x_mean, z_mu, z_var, -delta_logp.view(-1), z0, z
GitHub: rtqichen / ffjord / lib / layers / cnf.py (view on GitHub)
# Add regularization states.
        # One zero scalar per regularizer (self.nreg of them), converted to z's
        # dtype/device; during training they are integrated alongside (z, logp).
        reg_states = tuple(torch.tensor(0).to(z) for _ in range(self.nreg))

        if self.training:
            # Under dopri5, tolerances are given per state: self.atol/self.rtol
            # for (z, logp) and 1e20 for each regularization state, which
            # presumably exempts those states from step-size error control
            # (confirm against torchdiffeq). The conditional expression wraps
            # the whole list concatenation; other solvers get the scalar value.
            state_t = odeint(
                self.odefunc,
                (z, _logpz) + reg_states,
                integration_times.to(z),
                atol=[self.atol, self.atol] + [1e20] * len(reg_states) if self.solver == 'dopri5' else self.atol,
                rtol=[self.rtol, self.rtol] + [1e20] * len(reg_states) if self.solver == 'dopri5' else self.rtol,
                method=self.solver,
                options=self.solver_options,
            )
        else:
            # Evaluation: integrate (z, logp) only, with the test-time solver and
            # tolerances, and without self.solver_options.
            state_t = odeint(
                self.odefunc,
                (z, _logpz),
                integration_times.to(z),
                atol=self.test_atol,
                rtol=self.test_rtol,
                method=self.test_solver,
            )

        # With exactly two integration times, keep only the second (final)
        # time's value for each integrated component.
        if len(integration_times) == 2:
            state_t = tuple(s[1] for s in state_t)

        # The first two components are (z, logp); any remaining ones are the
        # regularization states (presumably empty in eval, where only two
        # states are integrated).
        z_t, logpz_t = state_t[:2]
        self.regularization_states = state_t[2:]

        if logpz is not None:
            return z_t, logpz_t
GitHub: rtqichen / torchdiffeq / examples / ode_demo.py (view on GitHub)
# Use the requested GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:' + str(args.gpu))
else:
    device = torch.device('cpu')

# NOTE(review): `device` is not applied to the tensors below (or anywhere else
# in this snippet), so the ground-truth data stays on the CPU. Confirm whether
# `.to(device)` was intended.
true_y0 = torch.tensor([[2.0, 0.0]])  # initial state, shape (1, 2)
t = torch.linspace(0.0, 25.0, args.data_size)  # args.data_size evenly spaced times on [0, 25]
true_A = torch.tensor([[-0.1, 2.0], [-2.0, -0.1]])  # 2x2 matrix used by the true dynamics


class Lambda(nn.Module):
    """Ground-truth dynamics used to generate the training data.

    ``forward(t, y)`` returns ``(y ** 3) @ true_A``; the time argument ``t``
    is ignored.
    """

    def forward(self, t, y):
        cubed = torch.pow(y, 3)
        return cubed.mm(true_A)


# Generate the ground-truth trajectory once, without tracking gradients:
# integrate the Lambda dynamics from true_y0 at every time in t using dopri5.
with torch.no_grad():
    true_y = odeint(Lambda(), true_y0, t, method='dopri5')


def get_batch():
    """Sample a random mini-batch of short windows of the ground-truth trajectory.

    Draws ``args.batch_size`` distinct start indices from
    ``[0, args.data_size - args.batch_time)``. For each start, collects the
    next ``args.batch_time`` states of ``true_y``.

    Returns:
        batch_y0: states at the start indices. With the (1, D) ``true_y0``
            of this script this is (M, 1, D), assuming ``true_y`` is stacked
            along a leading time axis.
        batch_t: the first ``args.batch_time`` entries of ``t``, shape (T,).
        batch_y: the windows stacked along a new leading axis, (T, M, 1, D).
    """
    start_pool = np.arange(args.data_size - args.batch_time, dtype=np.int64)
    starts = torch.from_numpy(
        np.random.choice(start_pool, args.batch_size, replace=False)
    )
    batch_t = t[:args.batch_time]
    windows = [true_y[starts + offset] for offset in range(args.batch_time)]
    return true_y[starts], batch_t, torch.stack(windows, dim=0)


def makedirs(dirname):
    """Create ``dirname`` and any missing parent directories, if not present.

    Tries the creation directly instead of checking ``os.path.exists`` first.
    This removes the window in which another process could create the path
    between the check and ``os.makedirs`` and crash this call. As before, a
    path that already exists, even as a non-directory, is left untouched.
    """
    try:
        os.makedirs(dirname)
    except FileExistsError:
        # Already present (possibly created concurrently): nothing to do.
        pass


if args.viz:
GitHub: rtqichen / ffjord / lib / layers / cnf.py (view on GitHub)
else:
            _logpz = logpz

        if integration_times is None:
            integration_times = torch.tensor([0.0, self.sqrt_end_time * self.sqrt_end_time]).to(z)
        if reverse:
            integration_times = _flip(integration_times, 0)

        # Refresh the odefunc statistics.
        self.odefunc.before_odeint()

        # Add regularization states.
        reg_states = tuple(torch.tensor(0).to(z) for _ in range(self.nreg))

        if self.training:
            state_t = odeint(
                self.odefunc,
                (z, _logpz) + reg_states,
                integration_times.to(z),
                atol=[self.atol, self.atol] + [1e20] * len(reg_states) if self.solver == 'dopri5' else self.atol,
                rtol=[self.rtol, self.rtol] + [1e20] * len(reg_states) if self.solver == 'dopri5' else self.rtol,
                method=self.solver,
                options=self.solver_options,
            )
        else:
            state_t = odeint(
                self.odefunc,
                (z, _logpz),
                integration_times.to(z),
                atol=self.test_atol,
                rtol=self.test_rtol,
                method=self.test_solver,