How to use the algorithms.common.helper_functions.soft_update function in algorithms

To help you get started, we’ve selected a few soft_update examples, based on popular ways the function is used in public projects.
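All of the snippets below call soft_update(local, target, tau) to blend a target network toward its online counterpart (Polyak averaging). The real implementation lives in algorithms/common/helper_functions.py; what follows is a minimal sketch of what such a helper typically looks like, assuming PyTorch modules, not a copy of the repository's code:

import torch
import torch.nn as nn

def soft_update(local: nn.Module, target: nn.Module, tau: float) -> None:
    """Polyak update: theta_target <- tau * theta_local + (1 - tau) * theta_target."""
    with torch.no_grad():
        for t_param, l_param in zip(target.parameters(), local.parameters()):
            t_param.mul_(1.0 - tau).add_(tau * l_param)

With a typical TAU around 5e-3, the target becomes an exponential moving average of the online weights with a horizon of roughly 1/TAU updates, which keeps bootstrapped critic targets from chasing a fast-moving estimate. Usage matches the calls in the snippets below, e.g. soft_update(self.actor, self.actor_target, tau).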


From medipixel/rl_algorithms, algorithms/ddpg/agent.py (view on GitHub):

        nn.utils.clip_grad_norm_(self.critic.parameters(), gradient_clip_cr)
        self.critic_optimizer.step()

        # train actor
        gradient_clip_ac = self.hyper_params["GRADIENT_CLIP_AC"]
        actions = self.actor(states)
        actor_loss = -self.critic(torch.cat((states, actions), dim=-1)).mean()
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        nn.utils.clip_grad_norm_(self.actor.parameters(), gradient_clip_ac)
        self.actor_optimizer.step()

        # update target networks
        tau = self.hyper_params["TAU"]
        common_utils.soft_update(self.actor, self.actor_target, tau)
        common_utils.soft_update(self.critic, self.critic_target, tau)

        return actor_loss.item(), critic_loss.item()
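
Soft updates assume the target networks start from the same weights as their online counterparts, which is normally done once when the agent is constructed. A minimal, self-contained sketch of that one-time hard sync (make_target is an illustrative name, not a helper from the repository):

import copy
import torch.nn as nn

def make_target(network: nn.Module) -> nn.Module:
    """Return a detached copy of the online network, to be tracked only via soft_update."""
    target = copy.deepcopy(network)      # hard sync: identical weights at step 0
    for param in target.parameters():
        param.requires_grad = False      # optional: targets are never trained by backprop
    return target
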
From medipixel/rl_algorithms, algorithms/bc/sac_agent.py (view on GitHub):

                std_reg = self.hyper_params["W_STD_REG"] * std.pow(2).mean()
                pre_activation_reg = self.hyper_params["W_PRE_ACTIVATION_REG"] * (
                    pre_tanh_value.pow(2).sum(dim=-1).mean()
                )
                actor_reg = mean_reg + std_reg + pre_activation_reg

                # actor loss + regularization
                actor_loss += actor_reg

            # train actor
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            # update target networks
            common_utils.soft_update(self.vf, self.vf_target, self.hyper_params["TAU"])
        else:
            actor_loss = torch.zeros(1)
            n_qf_mask = 0

        return (
            actor_loss.item(),
            qf_1_loss.item(),
            qf_2_loss.item(),
            vf_loss.item(),
            alpha_loss.item(),
            n_qf_mask,
        )
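
Note that this snippet soft-updates only the value-network target (vf_target), consistent with the original SAC formulation, where the Q-functions bootstrap from a slowly moving value target. To see what TAU itself does, treat a single weight as the whole network and apply the update repeatedly; a toy check using the soft_update sketch from the top of this page (not repository code):

import torch.nn as nn

online = nn.Linear(1, 1, bias=False)
target = nn.Linear(1, 1, bias=False)
nn.init.constant_(online.weight, 1.0)
nn.init.constant_(target.weight, 0.0)

tau = 5e-3
for _ in range(200):
    soft_update(online, target, tau)   # soft_update as sketched above

# After n updates the target holds 1 - (1 - tau)^n of the online weight:
print(target.weight.item())            # roughly 1 - 0.995 ** 200 ~= 0.633
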
From medipixel/rl_algorithms, algorithms/fd/ddpg_agent.py (view on GitHub):

        self.critic_optimizer.step()

        # train actor
        gradient_clip_ac = self.hyper_params["GRADIENT_CLIP_AC"]
        actions = self.actor(states)
        actor_loss_element_wise = -self.critic(torch.cat((states, actions), dim=-1))
        actor_loss = torch.mean(actor_loss_element_wise * weights)
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        nn.utils.clip_grad_norm_(self.actor.parameters(), gradient_clip_ac)
        self.actor_optimizer.step()

        # update target networks
        tau = self.hyper_params["TAU"]
        common_utils.soft_update(self.actor, self.actor_target, tau)
        common_utils.soft_update(self.critic, self.critic_target, tau)

        # update priorities
        new_priorities = critic_loss_element_wise
        new_priorities += self.hyper_params["LAMBDA3"] * actor_loss_element_wise.pow(2)
        new_priorities += self.hyper_params["PER_EPS"]
        new_priorities = new_priorities.data.cpu().numpy().squeeze()
        new_priorities += eps_d
        self.memory.update_priorities(indices, new_priorities)

        # increase beta
        fraction = min(float(self.i_episode) / self.args.episode_num, 1.0)
        self.beta = self.beta + fraction * (1.0 - self.beta)

        return actor_loss.item(), critic_loss.item()
From medipixel/rl_algorithms, algorithms/td3/agent.py (view on GitHub):

        if self.update_step % self.hyper_params["POLICY_UPDATE_FREQ"] == 0:
            # policy loss
            actions = self.actor(states)
            actor_loss = -self.critic1(states, actions).mean()

            # train actor
            self.actor_optim.zero_grad()
            actor_loss.backward()
            self.actor_optim.step()

            # update target networks
            tau = self.hyper_params["TAU"]
            common_utils.soft_update(self.critic1, self.critic_target1, tau)
            common_utils.soft_update(self.critic2, self.critic_target2, tau)
            common_utils.soft_update(self.actor, self.actor_target, tau)
        else:
            actor_loss = torch.zeros(1)

        return actor_loss.item(), critic1_loss.item(), critic2_loss.item()
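
In the TD3 snippet above, the targets are soft-updated only on delayed policy-update steps (every POLICY_UPDATE_FREQ critic updates), so they move even more slowly than TAU alone would suggest. A schematic sketch of that pattern, reusing the soft_update sketch from the top of this page (function and argument names here are illustrative, not the repository's API):

def update_targets_if_due(update_step: int, policy_update_freq: int, pairs, tau: float) -> None:
    """Soft-update each (online, target) pair only on delayed update steps, as TD3 does."""
    if update_step % policy_update_freq != 0:
        return
    for online, target in pairs:
        soft_update(online, target, tau)

For example: update_targets_if_due(step, 2, [(actor, actor_target), (critic1, critic_target1), (critic2, critic_target2)], tau=5e-3).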