How to use the learn2learn.clone_module function in learn2learn

To help you get started, we’ve selected a few learn2learn examples based on popular ways learn2learn.clone_module is used in public projects.

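As a quick orientation, here is a minimal sketch of the usual clone_module pattern: clone a module, adapt the clone with differentiable gradients, then backpropagate the adapted loss into the original parameters. The model, data, and adapt_lr below are illustrative placeholders, not taken from the examples that follow.

import torch
import learn2learn as l2l

model = torch.nn.Linear(4, 2)
loss_fn = torch.nn.MSELoss()
x, y = torch.randn(8, 4), torch.randn(8, 2)
adapt_lr = 0.1  # illustrative inner-loop step size

# clone_module duplicates the module while keeping its parameters attached
# to the original's autograd graph, so updates to the clone remain
# differentiable with respect to the original parameters.
clone = l2l.clone_module(model)
adapt_loss = loss_fn(clone(x), y)
grads = torch.autograd.grad(adapt_loss, clone.parameters(), create_graph=True)
l2l.update_module(clone, updates=[-adapt_lr * g for g in grads])

# Gradients of the adapted clone's loss flow back into model.parameters(),
# which is the property MAML-style meta-learning relies on.
meta_loss = loss_fn(clone(x), y)
meta_loss.backward()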

Source: learnables/learn2learn, tests/unit/util_tests.py (view on GitHub)

def test_module_clone(self):
    # From a unittest.TestCase; requires `import torch as th` and
    # `import learn2learn as l2l`. self.model, self.input, self.loss_func,
    # and self.optimizer_step are provided by the test fixture.
    original_output = self.model(self.input)
    original_loss = self.loss_func(original_output, th.tensor([[0., 0.]]))
    original_gradients = th.autograd.grad(original_loss,
                                          self.model.parameters(),
                                          retain_graph=True,
                                          create_graph=True)

    # Clone the model before updating the original with its own gradients.
    cloned_model = l2l.clone_module(self.model)
    self.optimizer_step(self.model, original_gradients)

    # The clone starts from the pre-update parameters, so its loss and
    # gradients are computed independently of the step taken above.
    cloned_output = cloned_model(self.input)
    cloned_loss = self.loss_func(cloned_output, th.tensor([[0., 0.]]))
    cloned_gradients = th.autograd.grad(cloned_loss,
                                        cloned_model.parameters(),
                                        retain_graph=True,
                                        create_graph=True)
    self.optimizer_step(cloned_model, cloned_gradients)

    # Identical inputs and starting parameters yield identical gradients,
    # so after one update per network the parameters must match exactly.
    for a, b in zip(self.model.parameters(), cloned_model.parameters()):
        assert th.equal(a, b)
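
What this test demonstrates: clone_module produces a module whose parameters start identical to the original's. Since both networks see the same input, their gradients match, and applying the same optimizer step to each leaves their parameters exactly equal, which the final assertion verifies.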

Source: learnables/learn2learn, examples/rl/maml_trpo.py (view on GitHub)

def meta_surrogate_loss(iteration_replays, iteration_policies, policy, baseline, tau, gamma, adapt_lr):
    # Requires `import learn2learn as l2l` and `from tqdm import tqdm`;
    # fast_adapt_a2c is defined elsewhere in maml_trpo.py.
    mean_loss = 0.0
    mean_kl = 0.0
    for task_replays, old_policy in tqdm(zip(iteration_replays, iteration_policies),
                                         total=len(iteration_replays),
                                         desc='Surrogate Loss',
                                         leave=False):
        train_replays = task_replays[:-1]
        valid_episodes = task_replays[-1]
        # Clone the meta-policy so per-task adaptation does not modify it.
        new_policy = l2l.clone_module(policy)

        # Fast adapt the clone on each batch of training episodes
        for train_episodes in train_replays:
            new_policy = fast_adapt_a2c(new_policy, train_episodes, adapt_lr,
                                        baseline, gamma, tau, first_order=False)

        # Useful values from the held-out validation episodes
        states = valid_episodes.state()
        actions = valid_episodes.action()
        next_states = valid_episodes.next_state()
        rewards = valid_episodes.reward()
        dones = valid_episodes.done()

        # Compute KL between the old and newly adapted policy distributions
        old_densities = old_policy.density(states)
        new_densities = new_policy.density(states)
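
The excerpt is truncated here; in the full meta_surrogate_loss, these densities presumably feed the KL divergence and surrogate loss that TRPO uses for the meta-update, as the function's name and the mean_kl and mean_loss accumulators suggest.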