# Requires: import torch as th, import learn2learn as l2l (method of the clone_module test case).
def test_module_clone(self):
    # Compute gradients of a loss on the original model, keeping the graph
    # so that higher-order gradients remain available.
    original_output = self.model(self.input)
    original_loss = self.loss_func(original_output, th.tensor([[0., 0.]]))
    original_gradients = th.autograd.grad(original_loss,
                                          self.model.parameters(),
                                          retain_graph=True,
                                          create_graph=True)
    cloned_model = l2l.clone_module(self.model)
    self.optimizer_step(self.model, original_gradients)

    # Repeat the same computation on the clone.
    cloned_output = cloned_model(self.input)
    cloned_loss = self.loss_func(cloned_output, th.tensor([[0., 0.]]))
    cloned_gradients = th.autograd.grad(cloned_loss,
                                        cloned_model.parameters(),
                                        retain_graph=True,
                                        create_graph=True)
    self.optimizer_step(cloned_model, cloned_gradients)

    # After identical updates, the clone's parameters should match the original's.
    for a, b in zip(self.model.parameters(), cloned_model.parameters()):
        assert th.equal(a, b)
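The test above relies on the key property of l2l.clone_module: the copy is made with differentiable clone() calls, so gradients of a loss computed with the clone can flow back to the original module's parameters. A minimal sketch of that property, using only public PyTorch and learn2learn calls; the model, input, and learning rate below are illustrative and not part of the test suite:

import torch as th
import learn2learn as l2l

model = th.nn.Linear(4, 2)
clone = l2l.clone_module(model)

x = th.randn(1, 4)
loss = clone(x).pow(2).sum()

# Because the clone's parameters were created with differentiable clones,
# gradients of the clone's loss can be taken w.r.t. the original parameters.
grads = th.autograd.grad(loss, model.parameters())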
def meta_surrogate_loss(iteration_replays, iteration_policies, policy, baseline, tau, gamma, adapt_lr):
    mean_loss = 0.0
    mean_kl = 0.0
    for task_replays, old_policy in tqdm(zip(iteration_replays, iteration_policies),
                                         total=len(iteration_replays),
                                         desc='Surrogate Loss',
                                         leave=False):
        train_replays = task_replays[:-1]
        valid_episodes = task_replays[-1]
        new_policy = l2l.clone_module(policy)

        # Fast Adapt
        for train_episodes in train_replays:
            new_policy = fast_adapt_a2c(new_policy, train_episodes, adapt_lr,
                                        baseline, gamma, tau, first_order=False)

        # Useful values
        states = valid_episodes.state()
        actions = valid_episodes.action()
        next_states = valid_episodes.next_state()
        rewards = valid_episodes.reward()
        dones = valid_episodes.done()

        # Compute KL
        old_densities = old_policy.density(states)
        new_densities = new_policy.density(states)
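As a side illustration only, not the continuation of meta_surrogate_loss above: if density() returns torch.distributions objects over actions, one common way to reduce a pair of old and new densities to a scalar KL term is torch.distributions.kl_divergence averaged over the batch. The shapes and distribution parameters below are made up for the sketch:

import torch as th
from torch.distributions import Normal, kl_divergence

# Stand-ins for old_densities / new_densities over 5 states and 2 action dims.
old = Normal(th.zeros(5, 2), th.ones(5, 2))
new = Normal(0.1 * th.ones(5, 2), th.ones(5, 2))

# Mean KL between the adapted and pre-adaptation action distributions.
kl = kl_divergence(new, old).mean()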