How to use the torchdiffeq._impl.rk_common._RungeKuttaState function in torchdiffeq

To help you get started, we’ve selected a few torchdiffeq examples based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

GitHub: rtqichen / torchdiffeq / torchdiffeq / _impl / adaptive_heun.py (View on GitHub)
#                     Error Ratio                      #
        ########################################################
        mean_sq_error_ratio = _compute_error_ratio(y1_error, atol=self.atol, rtol=self.rtol, y0=y0, y1=y1)
        accept_step = (torch.tensor(mean_sq_error_ratio) <= 1).all()

        ########################################################
        #                   Update RK State                    #
        ########################################################
        y_next = y1 if accept_step else y0
        f_next = f1 if accept_step else f0
        t_next = t0 + dt if accept_step else t0
        interp_coeff = _interp_fit_adaptive_heun(y0, y1, k, dt) if accept_step else interp_coeff
        dt_next = _optimal_step_size(
            dt, mean_sq_error_ratio, safety=self.safety, ifactor=self.ifactor, dfactor=self.dfactor, order=5
        )
        rk_state = _RungeKuttaState(y_next, f_next, t0, t_next, dt_next, interp_coeff)
        return rk_state
GitHub: rtqichen / torchdiffeq / torchdiffeq / _impl / adaptive_heun.py (View on GitHub)
def before_integrate(self, t):
    """Initialise ``self.rk_state`` before integration starts.

    Evaluates ``self.func`` once at the initial time and chooses the size of
    the first step:

    * automatically via ``_select_initial_step`` when ``self.first_step``
      is ``None``;
    * otherwise the caller-supplied ``self.first_step``, converted to a
      tensor with ``t``'s dtype and device.

    Args:
        t: Tensor of integration times. Only ``t[0]`` (the start time) and
            ``t``'s dtype/device are used here.
    """
    # Cast the start time to the dtype of the first state component before
    # evaluating the dynamics there.
    f0 = self.func(t[0].type_as(self.y0[0]), self.y0)
    if self.first_step is None:
        # NOTE(review): the order argument 4 matches the dopri5 solver's call;
        # confirm it is the intended value for the adaptive Heun solver.
        first_step = _select_initial_step(self.func, t[0], self.y0, 4, self.rtol[0], self.atol[0], f0=f0).to(t)
    else:
        # Respect the user-provided step instead of silently using 0.01.
        first_step = _convert_to_tensor(self.first_step, dtype=t.dtype, device=t.device)
    # The two time slots (t0 / t_next in the step update) both start at t[0],
    # and every interpolation coefficient is initialised to the initial state.
    self.rk_state = _RungeKuttaState(self.y0, f0, t[0], t[0], first_step, interp_coeff=[self.y0] * 5)
GitHub: uncbiag / easyreg / torchdiffeq / _impl / dopri5.py (View on GitHub)
def before_integrate(self, t):
    """Initialise ``self.rk_state`` before integration starts.

    Evaluates ``self.func`` once at the initial time and chooses the size of
    the first step:

    * automatically via ``_select_initial_step`` when ``self.first_step``
      is ``None``;
    * otherwise the caller-supplied ``self.first_step``, converted to a
      tensor with ``t``'s dtype and device.

    Args:
        t: Tensor of integration times. Only ``t[0]`` (the start time) and
            ``t``'s dtype/device are used here.
    """
    # Cast the start time to the dtype of the first state component before
    # evaluating the dynamics there.
    f0 = self.func(t[0].type_as(self.y0[0]), self.y0)
    if self.first_step is None:
        first_step = _select_initial_step(self.func, t[0], self.y0, 4, self.rtol[0], self.atol[0], f0=f0).to(t)
    else:
        # Respect the user-provided step instead of silently using 0.01.
        first_step = _convert_to_tensor(self.first_step, dtype=t.dtype, device=t.device)
    # The two time slots (t0 / t_next in the step update) both start at t[0],
    # and every interpolation coefficient is initialised to the initial state.
    self.rk_state = _RungeKuttaState(self.y0, f0, t[0], t[0], first_step, interp_coeff=[self.y0] * 5)
GitHub: uncbiag / easyreg / torchdiffeq / _impl / dopri5.py (View on GitHub)
y_next = y1 if accept_step else y0
            f_next = f1 if accept_step else f0
            t_next = t0 + dt if accept_step else t0
            interp_coeff = _interp_fit_dopri5(y0, y1, k, dt) if accept_step else interp_coeff
        else:
            if dt_next<0.02:
                print("warning the step of dopri5 {} is too small, set to 0.01".format(dt_next))
                dt_next = _convert_to_tensor(0.01, dtype=torch.float64, device=y0[0].device)
            if dt_next>0.1:
                print("warning the step of dopri5 {} is too big, set to 0.1".format(dt_next))
                dt_next = _convert_to_tensor(0.1, dtype=torch.float64, device=y0[0].device)
            y_next = y1
            f_next = f1
            t_next = t0 + dt
            interp_coeff = _interp_fit_dopri5(y0, y1, k, dt)
        rk_state = _RungeKuttaState(y_next, f_next, t0, t_next, dt_next, interp_coeff)
        return rk_state