def test_adjoint(self):
    for ode in problems.PROBLEMS.keys():
        f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, reverse=True)
        y = torchdiffeq.odeint_adjoint(f, y0, t_points, method='dopri5')
        with self.subTest(ode=ode):
            self.assertLess(rel_error(sol, y), error_tol)

# From a related gradient test: wrap the adjoint solve so it can be
# differentiated with respect to y0 and t_points.
func = lambda y0, t_points: torchdiffeq.odeint_adjoint(f, y0, t_points, method='dopri5')
ys = func(y0, t_points)
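For readers without the test harness (`problems`, `rel_error`, and `error_tol` are helpers from torchdiffeq's test suite), here is a minimal self-contained sketch of `odeint_adjoint` checked against a closed-form solution; the `Decay` module and the tolerance are illustrative assumptions, not part of the original tests:

import torch
import torch.nn as nn
from torchdiffeq import odeint_adjoint

class Decay(nn.Module):
    # dy/dt = -rate * y, with exact solution y(t) = y0 * exp(-rate * t)
    def __init__(self):
        super().__init__()
        self.rate = nn.Parameter(torch.tensor(0.5))

    def forward(self, t, y):
        return -self.rate * y

y0 = torch.tensor([[1.0]])
t = torch.linspace(0., 2., 10)
y = odeint_adjoint(Decay(), y0, t, method='dopri5')  # shape (10, 1, 1)
exact = torch.exp(-0.5 * t).reshape(-1, 1, 1)
assert (y - exact).abs().max().item() < 1e-4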
def test_adams_adjoint_against_dopri5(self):
    func, y0, t_points = self.problem()
    ys_ = torchdiffeq.odeint_adjoint(func, y0, t_points, method='adams')
    gradys = torch.rand_like(ys_) * 0.1
    ys_.backward(gradys)

    adj_y0_grad = y0.grad
    adj_t_grad = t_points.grad
    adj_A_grad = func.A.grad
    self.assertEqual(max_abs(func.unused_module.weight.grad), 0)
    self.assertEqual(max_abs(func.unused_module.bias.grad), 0)

    func, y0, t_points = self.problem()
    ys = torchdiffeq.odeint(func, y0, t_points, method='dopri5')
    ys.backward(gradys)
    self.assertLess(max_abs(y0.grad - adj_y0_grad), 5e-2)
    self.assertLess(max_abs(t_points.grad - adj_t_grad), 5e-4)
    self.assertLess(max_abs(func.A.grad - adj_A_grad), 2e-2)
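`max_abs` is a small helper from the same test file. A sketch consistent with its use, plus the same adjoint-vs-backprop gradient comparison on a toy problem (reusing the `Decay` module from the earlier sketch; the tolerance is an illustrative assumption):

from torchdiffeq import odeint, odeint_adjoint

def max_abs(x):
    # Largest absolute entry of a tensor, used for gradient comparisons.
    return x.abs().max()

def grad_y0(solver):
    y0 = torch.tensor([[1.0]], requires_grad=True)
    t = torch.linspace(0., 2., 10)
    ys = solver(Decay(), y0, t, method='dopri5')
    ys.sum().backward()
    return y0.grad

# Adjoint gradients should closely match backprop through the solver.
assert max_abs(grad_y0(odeint_adjoint) - grad_y0(odeint)) < 1e-5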
if __name__ == '__main__':
    ii = 0

    func = ODEFunc()
    optimizer = optim.RMSprop(func.parameters(), lr=1e-3)
    end = time.time()

    time_meter = RunningAverageMeter(0.97)
    loss_meter = RunningAverageMeter(0.97)

    for itr in range(1, args.niters + 1):
        optimizer.zero_grad()
        batch_y0, batch_t, batch_y = get_batch()
        pred_y = odeint(func, batch_y0, batch_t)
        loss = torch.mean(torch.abs(pred_y - batch_y))
        loss.backward()
        optimizer.step()

        time_meter.update(time.time() - end)
        loss_meter.update(loss.item())

        if itr % args.test_freq == 0:
            with torch.no_grad():
                pred_y = odeint(func, true_y0, t)
                loss = torch.mean(torch.abs(pred_y - true_y))
                print('Iter {:04d} | Total Loss {:.6f}'.format(itr, loss.item()))
                visualize(true_y, pred_y, func, ii)
                ii += 1

        end = time.time()
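`RunningAverageMeter` is a small utility from the same example script; a sketch matching its usage above (an exponential moving average with the given momentum) looks like this:

class RunningAverageMeter(object):
    """Tracks an exponential moving average of a scalar value."""

    def __init__(self, momentum=0.99):
        self.momentum = momentum
        self.reset()

    def reset(self):
        self.val = None
        self.avg = 0

    def update(self, val):
        if self.val is None:
            self.avg = val
        else:
            self.avg = self.avg * self.momentum + val * (1 - self.momentum)
        self.val = val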
def forward(self, x, eval_times=None):
    # Forward pass corresponds to solving ODE, so reset number of function
    # evaluations counter
    self.odefunc.nfe = 0

    if eval_times is None:
        integration_time = torch.tensor([0, 1]).float().type_as(x)
    else:
        integration_time = eval_times.type_as(x)

    if self.adjoint:
        out = odeint_adjoint(self.odefunc, x, integration_time,
                             rtol=self.tol, atol=self.tol, method='dopri5',
                             options={'max_num_steps': MAX_NUM_STEPS})
    else:
        out = odeint(self.odefunc, x, integration_time,
                     rtol=self.tol, atol=self.tol, method='dopri5',
                     options={'max_num_steps': MAX_NUM_STEPS})

    if eval_times is None:
        return out[1]  # Return only final time
    else:
        return out
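The block above assumes an `odefunc` that maintains an `nfe` (number of function evaluations) counter. A minimal sketch of such a module follows; the layer sizes and activation are illustrative assumptions:

import torch.nn as nn

class ODEFunc(nn.Module):
    def __init__(self, dim=64):
        super().__init__()
        self.nfe = 0  # reset externally before each solve, incremented per evaluation
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.Tanh(), nn.Linear(dim, dim))

    def forward(self, t, x):
        self.nfe += 1
        return self.net(x)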
def forward(self, x):
    self.log_det_j = 0.
    z_mu, z_var, am_params = self.encode(x)

    # Sample z_0
    z0 = self.reparameterize(z_mu, z_var)

    delta_logp = torch.zeros(x.shape[0], 1).to(x)
    z = z0
    for odefunc, am_param in zip(self.odefuncs, am_params):
        am_param_unpacked = odefunc.diffeq._unpack_params(am_param)
        odefunc.before_odeint()
        states = odeint(
            odefunc,
            (z, delta_logp) + tuple(am_param_unpacked),
            self.integration_times.to(z),
            atol=self.atol,
            rtol=self.rtol,
            method=self.solver,
        )
        z, delta_logp = states[0][-1], states[1][-1]

    x_mean = self.decode(z)

    return x_mean, z_mu, z_var, -delta_logp.view(-1), z0, z
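The loop above integrates a tuple state `(z, delta_logp, ...)`, the instantaneous change-of-variables trick used by continuous normalizing flows. A self-contained sketch with linear dynamics, where the divergence is simply the constant trace of the weight matrix (the real model instead estimates the divergence of a neural network), illustrates the pattern:

import torch
import torch.nn as nn
from torchdiffeq import odeint

class LinearCNF(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.weight = nn.Parameter(0.1 * torch.randn(dim, dim))

    def forward(self, t, states):
        z, dlp = states
        dz = z @ self.weight.t()
        # d(delta_logp)/dt = -div(f); for linear dynamics div(f) = trace(W).
        ddlp = -torch.diagonal(self.weight).sum() * torch.ones_like(dlp)
        return dz, ddlp

z0 = torch.randn(8, 2)
dlogp0 = torch.zeros(8, 1)
ts = torch.tensor([0.0, 1.0])
zs, dlogps = odeint(LinearCNF(2), (z0, dlogp0), ts, method='dopri5')
z1, delta_logp = zs[-1], dlogps[-1]  # final state and accumulated log-det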
device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')

true_y0 = torch.tensor([[2., 0.]])
t = torch.linspace(0., 25., args.data_size)
true_A = torch.tensor([[-0.1, 2.0], [-2.0, -0.1]])


class Lambda(nn.Module):
    def forward(self, t, y):
        return torch.mm(y**3, true_A)


with torch.no_grad():
    true_y = odeint(Lambda(), true_y0, t, method='dopri5')


def get_batch():
    s = torch.from_numpy(np.random.choice(np.arange(args.data_size - args.batch_time, dtype=np.int64), args.batch_size, replace=False))
    batch_y0 = true_y[s]  # (M, D)
    batch_t = t[:args.batch_time]  # (T)
    batch_y = torch.stack([true_y[s + i] for i in range(args.batch_time)], dim=0)  # (T, M, D)
    return batch_y0, batch_t, batch_y
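`args` is an argparse namespace defined at the top of the script. A plausible sketch of the flags referenced across these snippets (the defaults are illustrative assumptions):

import argparse

parser = argparse.ArgumentParser('ODE demo')
parser.add_argument('--data_size', type=int, default=1000)
parser.add_argument('--batch_time', type=int, default=10)
parser.add_argument('--batch_size', type=int, default=20)
parser.add_argument('--niters', type=int, default=2000)
parser.add_argument('--test_freq', type=int, default=20)
parser.add_argument('--gpu', type=int, default=0)
parser.add_argument('--viz', action='store_true')
args = parser.parse_args()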
import os

def makedirs(dirname):
    if not os.path.exists(dirname):
        os.makedirs(dirname)


if args.viz:
    makedirs('png')  # plotting setup elided in the original snippet
def forward(self, z, logpz=None, integration_times=None, reverse=False):
    if logpz is None:
        _logpz = torch.zeros(z.shape[0], 1).to(z)
    else:
        _logpz = logpz

    if integration_times is None:
        integration_times = torch.tensor([0.0, self.sqrt_end_time * self.sqrt_end_time]).to(z)
    if reverse:
        integration_times = _flip(integration_times, 0)

    # Refresh the odefunc statistics.
    self.odefunc.before_odeint()

    # Add regularization states.
    reg_states = tuple(torch.tensor(0).to(z) for _ in range(self.nreg))

    if self.training:
        state_t = odeint(
            self.odefunc,
            (z, _logpz) + reg_states,
            integration_times.to(z),
            atol=[self.atol, self.atol] + [1e20] * len(reg_states) if self.solver == 'dopri5' else self.atol,
            rtol=[self.rtol, self.rtol] + [1e20] * len(reg_states) if self.solver == 'dopri5' else self.rtol,
            method=self.solver,
            options=self.solver_options,
        )
    else:
        state_t = odeint(
            self.odefunc,
            (z, _logpz),
            integration_times.to(z),
            atol=self.test_atol,
            rtol=self.test_rtol,
            method=self.test_solver,
        )

    if len(integration_times) == 2:
        state_t = tuple(s[1] for s in state_t)

    z_t, logpz_t = state_t[:2]
    self.regularization_states = state_t[2:]

    if logpz is not None:
        return z_t, logpz_t
    else:
        return z_t
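Hypothetical usage of this forward pass, assuming `cnf` is a module with the method above: integrate data forward to accumulate the log-density change, then invert the flow with `reverse=True`:

z = torch.randn(16, 2)
logp0 = torch.zeros(16, 1)
z_t, logp_t = cnf(z, logpz=logp0)                # forward-time integration
z_rec, _ = cnf(z_t, logpz=logp_t, reverse=True)  # approximately recovers z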