How to use the tianshou.data.to_torch_as function in tianshou

To help you get started, we've selected a few tianshou examples based on popular ways to_torch_as is used in public projects.
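
All of the snippets below are drawn from the thu-ml/tianshou repository itself. Before diving in, here is a minimal, self-contained sketch (the variable names are made up for illustration) of what to_torch_as(x, y) does: it converts x, typically a numpy array coming out of a Batch or replay buffer, into a torch.Tensor with the same dtype and device as the reference tensor y.

import numpy as np
import torch

from tianshou.data import to_torch_as

ref = torch.zeros(3, dtype=torch.float32)   # reference tensor: fixes the target dtype and device
act = np.array([0, 1, 2])                   # e.g. actions stored as numpy in a replay buffer

act_t = to_torch_as(act, ref)               # torch.Tensor with ref's dtype and device
print(act_t.dtype, act_t.device)            # torch.float32 cpu

That is the whole pattern the examples below rely on: buffer data lives in numpy, network outputs live in torch, and to_torch_as moves the former onto the latter's dtype and device without hard-coding either.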

github thu-ml / tianshou / tianshou / policy / modelfree / pg.py
def learn(self, batch: Batch, batch_size: int, repeat: int,
              **kwargs) -> Dict[str, List[float]]:
        losses = []
        r = batch.returns
        if self._rew_norm and not np.isclose(r.std(), 0):
            batch.returns = (r - r.mean()) / r.std()
        for _ in range(repeat):
            for b in batch.split(batch_size):
                self.optim.zero_grad()
                dist = self(b).dist
                a = to_torch_as(b.act, dist.logits)
                r = to_torch_as(b.returns, dist.logits)
                loss = -(dist.log_prob(a) * r).sum()
                loss.backward()
                self.optim.step()
                losses.append(loss.item())
        return {'loss': losses}
github thu-ml / tianshou / tianshou / policy / modelfree / ppo.py
def learn(self, batch: Batch, batch_size: int, repeat: int,
              **kwargs) -> Dict[str, List[float]]:
        self._batch = batch_size
        losses, clip_losses, vf_losses, ent_losses = [], [], [], []
        v = []
        old_log_prob = []
        with torch.no_grad():
            for b in batch.split(batch_size, shuffle=False):
                v.append(self.critic(b.obs))
                old_log_prob.append(self(b).dist.log_prob(
                    to_torch_as(b.act, v[0])))
        batch.v = torch.cat(v, dim=0).squeeze(-1)  # old value
        batch.act = to_torch_as(batch.act, v[0])
        batch.logp_old = torch.cat(old_log_prob, dim=0).reshape(batch.v.shape)
        batch.returns = to_torch_as(batch.returns, v[0])
        if self._rew_norm:
            mean, std = batch.returns.mean(), batch.returns.std()
            if not np.isclose(std.item(), 0):
                batch.returns = (batch.returns - mean) / std
        batch.adv = batch.returns - batch.v
        if self._rew_norm:
            mean, std = batch.adv.mean(), batch.adv.std()
            if not np.isclose(std.item(), 0):
                batch.adv = (batch.adv - mean) / std
        for _ in range(repeat):
            for b in batch.split(batch_size):
                dist = self(b).dist
                value = self.critic(b.obs).squeeze(-1)
                ratio = (dist.log_prob(b.act).reshape(value.shape) - b.logp_old
                         ).exp().float()
                surr1 = ratio * b.adv
github thu-ml / tianshou / tianshou / policy / modelfree / a2c.py
def learn(self, batch: Batch, batch_size: int, repeat: int,
              **kwargs) -> Dict[str, List[float]]:
        self._batch = batch_size
        r = batch.returns
        if self._rew_norm and not np.isclose(r.std(), 0):
            batch.returns = (r - r.mean()) / r.std()
        losses, actor_losses, vf_losses, ent_losses = [], [], [], []
        for _ in range(repeat):
            for b in batch.split(batch_size):
                self.optim.zero_grad()
                dist = self(b).dist
                v = self.critic(b.obs).squeeze(-1)
                a = to_torch_as(b.act, v)
                r = to_torch_as(b.returns, v)
                a_loss = -(dist.log_prob(a).reshape(v.shape) * (r - v).detach()
                           ).mean()
                vf_loss = F.mse_loss(r, v)
                ent_loss = dist.entropy().mean()
                loss = a_loss + self._w_vf * vf_loss - self._w_ent * ent_loss
                loss.backward()
                if self._grad_norm is not None:
                    nn.utils.clip_grad_norm_(
                        list(self.actor.parameters()) +
                        list(self.critic.parameters()),
                        max_norm=self._grad_norm)
                self.optim.step()
                actor_losses.append(a_loss.item())
                vf_losses.append(vf_loss.item())
                ent_losses.append(ent_loss.item())
github thu-ml / tianshou / tianshou / policy / modelfree / ddpg.py
        :return: A :class:`~tianshou.data.Batch` which has 2 keys:

            * ``act`` the action.
            * ``state`` the hidden state.

        .. seealso::

            Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
            more detailed explanation.
        """
        model = getattr(self, model)
        obs = getattr(batch, input)
        actions, h = model(obs, state=state, info=batch.info)
        actions += self._action_bias
        if self.training and explorating:
            actions += to_torch_as(self._noise(actions.shape), actions)
        actions = actions.clamp(self._range[0], self._range[1])
        return Batch(act=actions, state=h)
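
The DDPG forward above adds exploration noise, produced as a numpy array, to the actor's torch output; to_torch_as bridges the two before the clamp. A small sketch of that pattern, where gaussian_noise is a hypothetical stand-in for tianshou's exploration-noise objects:

import numpy as np
import torch

from tianshou.data import to_torch_as

def gaussian_noise(shape, sigma=0.1):
    # hypothetical stand-in for an exploration-noise object returning a numpy array
    return np.random.normal(scale=sigma, size=tuple(shape))

actions = torch.tanh(torch.randn(4, 2))                      # actor output in [-1, 1]
noise = to_torch_as(gaussian_noise(actions.shape), actions)  # numpy noise -> matching tensor
actions = (actions + noise).clamp(-1.0, 1.0)                 # keep actions inside the range
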
github thu-ml / tianshou / tianshou / policy / modelfree / dqn.py
def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
        if self._target and self._cnt % self._freq == 0:
            self.sync_weight()
        self.optim.zero_grad()
        q = self(batch).logits
        q = q[np.arange(len(q)), batch.act]
        r = to_torch_as(batch.returns, q)
        if hasattr(batch, 'update_weight'):
            td = r - q
            batch.update_weight(batch.indice, to_numpy(td))
            impt_weight = to_torch_as(batch.impt_weight, q)
            loss = (td.pow(2) * impt_weight).mean()
        else:
            loss = F.mse_loss(q, r)
        loss.backward()
        self.optim.step()
        self._cnt += 1
        return {'loss': loss.item()}
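
The DQN learn step shows the full round trip between the numpy-based prioritized buffer and the torch side: returns and importance weights come in through to_torch_as, and the TD error goes back out through to_numpy for batch.update_weight. A sketch with made-up tensors standing in for the real batch:

import numpy as np
import torch

from tianshou.data import to_numpy, to_torch_as

q = torch.randn(8, requires_grad=True)       # Q(s, a) predicted by the network
returns = np.random.randn(8)                 # n-step returns stored as numpy
impt_weight = np.random.rand(8)              # prioritized-replay importance weights

r = to_torch_as(returns, q)                  # float64 numpy -> q's dtype and device
w = to_torch_as(impt_weight, q)
td = r - q                                   # TD error, differentiable w.r.t. q
loss = (td.pow(2) * w).mean()                # weighted MSE, as in the snippet above

new_priority = to_numpy(td.detach())         # back to numpy for the buffer's update_weight
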
github thu-ml / tianshou / tianshou / policy / base.py
            if np.isclose(std, 0):
                mean, std = 0, 1
        else:
            mean, std = 0, 1
        returns = np.zeros_like(indice)
        gammas = np.zeros_like(indice) + n_step
        done, buf_len = buffer.done, len(buffer)
        for n in range(n_step - 1, -1, -1):
            now = (indice + n) % buf_len
            gammas[done[now] > 0] = n
            returns[done[now] > 0] = 0
            returns = (rew[now] - mean) / std + gamma * returns
        terminal = (indice + n_step - 1) % buf_len
        target_q = target_q_fn(buffer, terminal).flatten()  # shape: [bsz, ]
        target_q[gammas != n_step] = 0
        returns = to_torch_as(returns, target_q)
        gammas = to_torch_as(gamma ** gammas, target_q)
        batch.returns = target_q * gammas + returns
        return batch
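
This base.py excerpt is the tail of the n-step return computation: the returns and per-transition discount exponents are accumulated in numpy, then promoted to the target Q-values' dtype and device only at the final step. A sketch of that last combination, using hypothetical arrays in place of the buffer-derived ones:

import numpy as np
import torch

from tianshou.data import to_torch_as

gamma = 0.99
returns = np.random.randn(8)          # accumulated (normalized) n-step rewards
gammas = np.full(8, 3)                # effective discount exponent per transition
target_q = torch.randn(8)             # bootstrap values from the target network

returns_t = to_torch_as(returns, target_q)
discount = to_torch_as(gamma ** gammas, target_q)
nstep_target = target_q * discount + returns_t
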
github thu-ml / tianshou / tianshou / policy / modelfree / sac.py
def forward(self, batch: Batch,
                state: Optional[Union[dict, Batch, np.ndarray]] = None,
                input: str = 'obs',
                explorating: bool = True,
                **kwargs) -> Batch:
        obs = getattr(batch, input)
        logits, h = self.actor(obs, state=state, info=batch.info)
        assert isinstance(logits, tuple)
        dist = DiagGaussian(*logits)
        x = dist.rsample()
        y = torch.tanh(x)
        act = y * self._action_scale + self._action_bias
        y = self._action_scale * (1 - y.pow(2)) + self.__eps
        log_prob = dist.log_prob(x) - torch.log(y).sum(-1, keepdim=True)
        if self._noise is not None and self.training and explorating:
            act += to_torch_as(self._noise(act.shape), act)
        act = act.clamp(self._range[0], self._range[1])
        return Batch(
            logits=logits, act=act, state=h, dist=dist, log_prob=log_prob)
github thu-ml / tianshou / tianshou / policy / modelfree / sac.py
def _target_q(self, buffer: ReplayBuffer,
                  indice: np.ndarray) -> torch.Tensor:
        batch = buffer[indice]  # batch.obs: s_{t+n}
        with torch.no_grad():
            obs_next_result = self(batch, input='obs_next', explorating=False)
            a_ = obs_next_result.act
            batch.act = to_torch_as(batch.act, a_)
            target_q = torch.min(
                self.critic1_old(batch.obs_next, a_),
                self.critic2_old(batch.obs_next, a_),
            ) - self._alpha * obs_next_result.log_prob
        return target_q