Secure your code as it's written. Use Snyk Code to scan source code in minutes—no build needed—and fix issues immediately.
f0 = tuple(k_[0] for k_ in k)
f1 = tuple(k_[-1] for k_ in k)
return _interp_fit(y0, y1, y_mid, f0, f1, dt)
def _abs_square(x):
    """Return the elementwise squared magnitude ``|x|**2`` of tensor ``x``.

    For real tensors this is ``x * x`` (computed with ``torch.mul``, unchanged).
    For complex tensors ``x * x`` would be ``x**2``, a complex number rather than
    the squared modulus. So the squares of the real and imaginary parts are
    summed instead, which yields a real-valued tensor.

    Args:
        x: A ``torch.Tensor`` (real or complex).

    Returns:
        A tensor with the same shape as ``x`` holding ``|x|**2``.
    """
    if torch.is_complex(x):
        # re^2 + im^2 is the squared modulus and stays real-valued.
        return torch.mul(x.real, x.real) + torch.mul(x.imag, x.imag)
    return torch.mul(x, x)
def _ta_append(list_of_tensors, value):
    """Add ``value`` to the end of ``list_of_tensors`` in place.

    The very same list object is handed back (not a copy), so callers may use
    the call either for its side effect or inside an expression.
    """
    push = list_of_tensors.append
    push(value)
    return list_of_tensors
class AdaptiveHeunSolver(AdaptiveStepsizeODESolver):
    """Adaptive step-size ODE solver; only its configuration is visible here.

    Presumably it steps with Heun's method (inferred from the class name only;
    the stepping code is not part of this chunk, so confirm).
    """

    def __init__(
        self, func, y0, rtol, atol, first_step=None, safety=0.9, ifactor=10.0, dfactor=0.2, max_num_steps=2**31 - 1,
        **unused_kwargs
    ):
        """Store the ODE function, initial state and step-control settings.

        Args:
            func: Right-hand side of the ODE; stored as-is.
            y0: Indexable, sized sequence whose first element has a ``.device``
                (presumably a tuple of tensors; confirm). ``y0[0].device`` decides
                where the control tensors are placed.
            rtol: Relative tolerance. If not iterable, it is repeated once per
                element of ``y0``; otherwise it is stored as given.
            atol: Absolute tolerance, handled exactly like ``rtol``.
            first_step: Optional initial step size; stored as-is.
            safety: Step-control factor, converted with ``_convert_to_tensor``
                (dtype ``torch.float64``, on ``y0[0].device``).
            ifactor: Step-control factor, converted like ``safety``.
            dfactor: Step-control factor, converted like ``safety``.
            max_num_steps: Step budget; stored as-is.
            **unused_kwargs: Passed to ``_handle_unused_kwargs``, then discarded.
        """
        _handle_unused_kwargs(self, unused_kwargs)
        del unused_kwargs

        # Look the target device up once instead of once per converted factor.
        device = y0[0].device

        self.func = func
        self.y0 = y0
        self.rtol = rtol if _is_iterable(rtol) else [rtol] * len(y0)
        self.atol = atol if _is_iterable(atol) else [atol] * len(y0)
        self.first_step = first_step
        self.safety = _convert_to_tensor(safety, dtype=torch.float64, device=device)
        self.ifactor = _convert_to_tensor(ifactor, dtype=torch.float64, device=device)
        self.dfactor = _convert_to_tensor(dfactor, dtype=torch.float64, device=device)
        # Previously accepted but silently dropped. It is stored unconverted so
        # that no caller-supplied value can be rejected or overflow a dtype.
        self.max_num_steps = max_num_steps
f0 = tuple(k_[0] for k_ in k)
f1 = tuple(k_[-1] for k_ in k)
return _interp_fit(y0, y1, y_mid, f0, f1, dt)
def _abs_square(x):
    """Return the elementwise squared magnitude ``|x|**2`` of tensor ``x``.

    For real tensors this is ``x * x`` (computed with ``torch.mul``, unchanged).
    For complex tensors ``x * x`` would be ``x**2``, a complex number rather than
    the squared modulus. So the squares of the real and imaginary parts are
    summed instead, which yields a real-valued tensor.

    Args:
        x: A ``torch.Tensor`` (real or complex).

    Returns:
        A tensor with the same shape as ``x`` holding ``|x|**2``.
    """
    if torch.is_complex(x):
        # re^2 + im^2 is the squared modulus and stays real-valued.
        return torch.mul(x.real, x.real) + torch.mul(x.imag, x.imag)
    return torch.mul(x, x)
def _ta_append(list_of_tensors, value):
    """Mutate ``list_of_tensors`` by adding ``value`` at its end; return that same list."""
    target = list_of_tensors
    target.append(value)
    return target
class Dopri5Solver(AdaptiveStepsizeODESolver):
    """Adaptive step-size ODE solver; only the start of its configuration is visible.

    Presumably a Dormand-Prince 5(4) integrator (inferred from the class name
    only; the stepping code is not part of this chunk, so confirm).
    """

    def __init__(
        self, func, y0, rtol, atol, first_step=None, safety=0.9, ifactor=10.0, dfactor=0.2, max_num_steps=2**31 - 1,
        **unused_kwargs
    ):
        """Store the ODE function, initial state and step-control settings.

        Args:
            func: Right-hand side of the ODE; stored as-is.
            y0: Indexable, sized sequence whose first element has a ``.device``
                (presumably a tuple of tensors; confirm).
            rtol: Relative tolerance. If not iterable, it is repeated once per
                element of ``y0``; otherwise it is stored as given.
            atol: Absolute tolerance, handled exactly like ``rtol``.
            first_step: Optional initial step size; stored as-is.
            safety: Step-control factor, converted with ``_convert_to_tensor``
                (dtype ``torch.float64``, on ``y0[0].device``).
            ifactor: Step-control factor, converted like ``safety``.
            dfactor: Step-control factor, converted like ``safety``.
            max_num_steps: Step budget.
            **unused_kwargs: Passed to ``_handle_unused_kwargs``, then discarded.
        """
        _handle_unused_kwargs(self, unused_kwargs)
        del unused_kwargs
        self.func = func
        self.y0 = y0
        # Broadcast a scalar tolerance to one entry per element of y0.
        self.rtol = rtol if _is_iterable(rtol) else [rtol] * len(y0)
        self.atol = atol if _is_iterable(atol) else [atol] * len(y0)
        self.first_step = first_step
        # Control factors become float64 tensors on the same device as y0[0].
        self.safety = _convert_to_tensor(safety, dtype=torch.float64, device=y0[0].device)
        self.ifactor = _convert_to_tensor(ifactor, dtype=torch.float64, device=y0[0].device)
        self.dfactor = _convert_to_tensor(dfactor, dtype=torch.float64, device=y0[0].device)
        # NOTE(review): max_num_steps is not stored in the lines visible here;
        # it may be handled past the end of this chunk. Confirm it is not dropped.