How to use the catalyst.rl.utils.any2device function in catalyst

To help you get started, we’ve selected a few catalyst examples that show popular ways any2device is used in public projects.

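Before diving into the snippets, here is a minimal hedged sketch of what `any2device` is for: it takes numpy arrays, torch tensors, or containers of them and returns the same structure with everything moved to the given device. The import path and batch shapes below are assumptions for illustration, not taken from the snippets.

```python
import numpy as np
import torch

# Assumed import path: the snippets below use `utils.any2device` from
# catalyst's RL utilities; the exact module layout varies across catalyst versions.
from catalyst.rl import utils

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# A replay-buffer style batch of numpy arrays (shapes are made up) ...
batch = {
    "state": np.random.randn(32, 4, 24).astype(np.float32),  # [bs; history_len; obs_len]
    "action": np.random.randn(32, 6).astype(np.float32),     # [bs; action_len]
    "reward": np.random.randn(32).astype(np.float32),        # [bs]
}

# ... becomes a dict of torch tensors living on `device`; `any2device` also
# accepts single arrays/tensors, which is how the snippets below use it.
batch = utils.any2device(batch, device=device)
print(batch["state"].device, batch["state"].dtype)  # e.g. cuda:0 torch.float32
```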

catalyst-team / catalyst · catalyst/rl/offpolicy/algorithms/actor_critic.py (view on GitHub)
def train(self, batch, actor_update=True, critic_update=True):
        states_t, actions_t, rewards_t, states_tp1, done_t = \
            batch["state"], batch["action"], batch["reward"], \
            batch["next_state"], batch["done"]

        states_t = utils.any2device(states_t, device=self._device)
        actions_t = utils.any2device(actions_t, device=self._device)
        rewards_t = utils.any2device(
            rewards_t, device=self._device
        ).unsqueeze(1)
        states_tp1 = utils.any2device(states_tp1, device=self._device)
        done_t = utils.any2device(done_t, device=self._device).unsqueeze(1)
        """
        states_t: [bs; history_len; observation_len]
        actions_t: [bs; action_len]
        rewards_t: [bs; 1]
        states_tp1: [bs; history_len; observation_len]
        done_t: [bs; 1]
        """

        policy_loss, value_loss = self._loss_fn(
            states_t, actions_t, rewards_t, states_tp1, done_t
        )
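One pattern worth noting in this snippet (and in most of the examples that follow): every field sampled from the buffer arrives as a numpy array and is pushed to the training device with `any2device`, while per-sample scalars such as `reward` and `done` get an extra `.unsqueeze(1)` so they line up with `[bs; 1]` critic outputs. A small illustration of that shape handling, with invented sizes:

```python
import torch

bs = 32
rewards_t = torch.randn(bs)          # [bs], as it comes back from any2device
done_t = torch.zeros(bs)             # [bs]

rewards_t = rewards_t.unsqueeze(1)   # [bs; 1]
done_t = done_t.unsqueeze(1)         # [bs; 1]

values_tp1 = torch.randn(bs, 1)      # critic output, [bs; 1]
# With matching trailing dims, the TD target broadcasts elementwise.
td_target = rewards_t + 0.99 * (1 - done_t) * values_tp1  # [bs; 1]
```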
catalyst-team / catalyst · catalyst/rl/onpolicy/algorithms/ppo.py (view on GitHub)
def get_rollout(self, states, actions, rewards, dones):
        assert len(states) == len(actions) == len(rewards) == len(dones)

        trajectory_len = \
            rewards.shape[0] if dones[-1] else rewards.shape[0] - 1
        states_len = states.shape[0]

        states = utils.any2device(states, device=self._device)
        actions = utils.any2device(actions, device=self._device)
        rewards = np.array(rewards)[:trajectory_len]
        values = torch.zeros(
            (states_len + 1, self._num_heads, self._num_atoms)).\
            to(self._device)
        values[:states_len, ...] = self.critic(states).squeeze_(dim=2)
        # Each column corresponds to a different gamma
        values = values.cpu().numpy()[:trajectory_len + 1, ...]
        _, logprobs = self.actor(states, logprob=actions)
        logprobs = logprobs.cpu().numpy().reshape(-1)[:trajectory_len]
        # len x num_heads
        deltas = rewards[:, None, None] \
            + self._gammas[:, None] * values[1:] - values[:-1]

        # For each gamma in the list of gammas compute the
        # advantage and returns
        # len x num_heads x num_atoms
catalyst-team / catalyst · catalyst/rl/offpolicy/algorithms/ddpg.py (view on GitHub)
            values_range = self.critic.values_range
            self.v_min, self.v_max = values_range
            self.delta_z = (self.v_max - self.v_min) / (self.num_atoms - 1)
            z = torch.linspace(
                start=self.v_min, end=self.v_max, steps=self.num_atoms
            )
            self.z = utils.any2device(z, device=self._device)
            self._loss_fn = self._categorical_loss
        elif critic_distribution == "quantile":
            self.num_atoms = self.critic.num_atoms
            tau_min = 1 / (2 * self.num_atoms)
            tau_max = 1 - tau_min
            tau = torch.linspace(
                start=tau_min, end=tau_max, steps=self.num_atoms
            )
            self.tau = utils.any2device(tau, device=self._device)
            self._loss_fn = self._quantile_loss
        else:
            assert self.critic_criterion is not None
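The DDPG snippet shows the other common use of `any2device`: constant support tensors built once at construction time, the categorical atoms `z` and the quantile midpoints `tau`, are created on CPU with `torch.linspace` and moved to the device a single time. A hedged sketch with made-up hyperparameters:

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_atoms, v_min, v_max = 51, -10.0, 10.0  # invented values

# Categorical (C51-style) support: num_atoms evenly spaced atoms in [v_min, v_max].
delta_z = (v_max - v_min) / (num_atoms - 1)
z = torch.linspace(start=v_min, end=v_max, steps=num_atoms)

# Quantile support: midpoints of num_atoms equal-probability bins.
tau_min = 1 / (2 * num_atoms)
tau = torch.linspace(start=tau_min, end=1 - tau_min, steps=num_atoms)

# In the snippet this is utils.any2device(..., device=self._device); for a
# plain tensor that amounts to a .to(device) transfer.
z, tau = z.to(device), tau.to(device)
```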
catalyst-team / catalyst · catalyst/rl/offpolicy/algorithms/td3.py (view on GitHub)
        self._gammas = \
            utils.hyperbolic_gammas(
                self._gamma,
                self._hyperbolic_constant,
                self._num_heads
            )
        self._gammas = utils.any2device(self._gammas, device=self._device)
        assert critic_distribution in [None, "categorical", "quantile"]

        if critic_distribution == "categorical":
            self.num_atoms = self.critic.num_atoms
            values_range = self.critic.values_range
            self.v_min, self.v_max = values_range
            self.delta_z = (self.v_max - self.v_min) / (self.num_atoms - 1)
            z = torch.linspace(
                start=self.v_min, end=self.v_max, steps=self.num_atoms
            )
            self.z = utils.any2device(z, device=self._device)
            self._loss_fn = self._categorical_loss
        elif critic_distribution == "quantile":
            self.num_atoms = self.critic.num_atoms
            tau_min = 1 / (2 * self.num_atoms)
            tau_max = 1 - tau_min
            tau = torch.linspace(
                start=tau_min, end=tau_max, steps=self.num_atoms
            )
            self.tau = utils.any2device(tau, device=self._device)
            self._loss_fn = self._quantile_loss
        else:
            assert self.critic_criterion is not None
catalyst-team / catalyst · catalyst/rl/onpolicy/algorithms/a2c.py (view on GitHub)
        (
            states_t, actions_t, returns_t,
            advantages_t, action_logprobs_t
        ) = (
            batch["state"], batch["action"], batch["return"],
            batch["advantage"], batch["action_logprob"]
        )

        states_t = utils.any2device(states_t, device=self._device)
        actions_t = utils.any2device(actions_t, device=self._device)
        returns_t = utils.any2device(
            returns_t, device=self._device
        ).unsqueeze_(-1)

        advantages_t = utils.any2device(advantages_t, device=self._device)
        action_logprobs_t = utils.any2device(
            action_logprobs_t, device=self._device
        )

        # critic loss
        values_tp0 = self.critic(states_t).squeeze_(dim=2)
        advantages_tp0 = (returns_t - values_tp0)
        value_loss = 0.5 * advantages_tp0.pow(2).mean()

        # actor loss
        _, action_logprobs_tp0 = self.actor(states_t, logprob=actions_t)
        policy_loss = -(advantages_t.detach() * action_logprobs_tp0).mean()
catalyst-team / catalyst · catalyst/rl/onpolicy/algorithms/ppo.py (view on GitHub)
        self.clip_eps = clip_eps
        self.entropy_regularization = entropy_regularization

        critic_distribution = self.critic.distribution
        self._value_loss_fn = self._base_value_loss
        self._num_atoms = self.critic.num_atoms
        self._num_heads = self.critic.num_heads
        self._hyperbolic_constant = self.critic.hyperbolic_constant
        self._gammas = \
            utils.hyperbolic_gammas(
                self._gamma,
                self._hyperbolic_constant,
                self._num_heads
            )
        # 1 x num_heads x 1
        self._gammas_torch = utils.any2device(
            self._gammas, device=self._device
        )[None, :, None]

        if critic_distribution == "categorical":
            self.num_atoms = self.critic.num_atoms
            values_range = self.critic.values_range
            self.v_min, self.v_max = values_range
            self.delta_z = (self.v_max - self.v_min) / (self._num_atoms - 1)
            z = torch.linspace(
                start=self.v_min, end=self.v_max, steps=self._num_atoms
            )
            self.z = utils.any2device(z, device=self._device)
            self._value_loss_fn = self._categorical_value_loss
        elif critic_distribution == "quantile":
            assert self.critic_criterion is not None
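The `[None, :, None]` indexing right after `any2device` above is a broadcasting trick: the per-head discount factors come back as a 1-D tensor of length `num_heads` and are reshaped to `[1; num_heads; 1]` so they can multiply `[bs; num_heads; num_atoms]` value tensors elementwise. A minimal illustration with invented sizes:

```python
import torch

bs, num_heads, num_atoms = 32, 3, 51           # invented sizes

gammas = torch.tensor([0.9, 0.99, 0.999])      # [num_heads]
gammas = gammas[None, :, None]                 # [1; num_heads; 1]

values_tp1 = torch.randn(bs, num_heads, num_atoms)
discounted = gammas * values_tp1               # broadcasts to [bs; num_heads; num_atoms]
assert discounted.shape == (bs, num_heads, num_atoms)
```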
catalyst-team / catalyst · catalyst/rl/onpolicy/algorithms/ppo.py (view on GitHub)
def train(self, batch, **kwargs):
        (
            states_t, actions_t, returns_t, states_tp1, done_t, values_t,
            advantages_t, action_logprobs_t
        ) = (
            batch["state"], batch["action"], batch["return"],
            batch["state_tp1"], batch["done"], batch["value"],
            batch["advantage"], batch["action_logprob"]
        )

        states_t = utils.any2device(states_t, device=self._device)
        actions_t = utils.any2device(actions_t, device=self._device)
        returns_t = utils.any2device(
            returns_t, device=self._device
        ).unsqueeze_(-1)
        states_tp1 = utils.any2device(states_tp1, device=self._device)
        done_t = utils.any2device(done_t, device=self._device)[:, None, None]
        # done_t = done_t[:, None, :]  # [bs; 1; 1]

        values_t = utils.any2device(values_t, device=self._device)
        advantages_t = utils.any2device(advantages_t, device=self._device)
        action_logprobs_t = utils.any2device(
            action_logprobs_t, device=self._device
        )

        # critic loss
        # states_t - [bs; {state_shape}]
        # values_t - [bs; num_heads; num_atoms]
catalyst-team / catalyst · catalyst/rl/offpolicy/algorithms/td3.py (view on GitHub)
            values_range = self.critic.values_range
            self.v_min, self.v_max = values_range
            self.delta_z = (self.v_max - self.v_min) / (self.num_atoms - 1)
            z = torch.linspace(
                start=self.v_min, end=self.v_max, steps=self.num_atoms
            )
            self.z = utils.any2device(z, device=self._device)
            self._loss_fn = self._categorical_loss
        elif critic_distribution == "quantile":
            self.num_atoms = self.critic.num_atoms
            tau_min = 1 / (2 * self.num_atoms)
            tau_max = 1 - tau_min
            tau = torch.linspace(
                start=tau_min, end=tau_max, steps=self.num_atoms
            )
            self.tau = utils.any2device(tau, device=self._device)
            self._loss_fn = self._quantile_loss
        else:
            assert self.critic_criterion is not None
catalyst-team / catalyst · catalyst/rl/onpolicy/algorithms/reinforce.py (view on GitHub)
def train(self, batch, **kwargs):
        states, actions, returns, action_logprobs = \
            batch["state"], batch["action"], batch["return"],\
            batch["action_logprob"]

        states = utils.any2device(states, device=self._device)
        actions = utils.any2device(actions, device=self._device)
        returns = utils.any2device(returns, device=self._device)
        old_logprobs = utils.any2device(action_logprobs, device=self._device)

        # actor loss
        _, logprobs = self.actor(states, logprob=actions)

        # REINFORCE objective function
        policy_loss = -torch.mean(logprobs * returns)

        if self.entropy_regularization is not None:
            entropy = -(torch.exp(logprobs) * logprobs).mean()
            entropy_loss = self.entropy_regularization * entropy
            policy_loss = policy_loss + entropy_loss

        # actor update
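To close, here is a hedged, self-contained version of the REINFORCE step above with a toy Gaussian actor, showing where `any2device` fits in a full update. The actor, batch shapes, and hyperparameters are all invented for illustration; only the `any2device` call and the `-torch.mean(logprobs * returns)` objective mirror the snippet.

```python
import numpy as np
import torch
import torch.nn as nn
from catalyst.rl import utils  # assumed import path, as in the snippets

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Toy "actor": a Gaussian policy over a 24-dim observation (made up).
obs_dim, act_dim, bs = 24, 6, 32
mean_net = nn.Linear(obs_dim, act_dim).to(device)
log_std = nn.Parameter(torch.zeros(act_dim, device=device))
optimizer = torch.optim.Adam(list(mean_net.parameters()) + [log_std], lr=3e-4)

# A rollout batch of numpy arrays, as a sampler would hand it over.
batch = {
    "state": np.random.randn(bs, obs_dim).astype(np.float32),
    "action": np.random.randn(bs, act_dim).astype(np.float32),
    "return": np.random.randn(bs).astype(np.float32),
}
batch = utils.any2device(batch, device=device)
states, actions, returns = batch["state"], batch["action"], batch["return"]

# REINFORCE objective, mirroring -torch.mean(logprobs * returns) above.
dist = torch.distributions.Normal(mean_net(states), log_std.exp())
logprobs = dist.log_prob(actions).sum(dim=-1)  # [bs]
policy_loss = -torch.mean(logprobs * returns)

# actor update
optimizer.zero_grad()
policy_loss.backward()
optimizer.step()
```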