How to use the tianshou.data.batch.Batch class in tianshou

To help you get started, we’ve selected a few tianshou examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github thu-ml / tianshou / tianshou / data / utils.py View on Github external
def to_torch(x: Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor],
             dtype: Optional[torch.dtype] = None,
             device: Union[str, int, torch.device] = 'cpu'
             ) -> Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor]:
    """Return an object without np.ndarray.

    The input is converted recursively, depending on its type:

    * ``torch.Tensor``: cast to ``dtype`` (if given) and moved to ``device``.
    * ``dict``: every value is converted and written back into the *same*
      dict, which is then returned (the input is modified in place).
    * ``Batch``: delegated to ``x.to_torch(dtype, device)``; the same object
      is returned (this relies on ``Batch.to_torch`` converting in place).
    * numeric / boolean scalar: wrapped with ``np.asanyarray`` and converted.
    * ``list`` / ``tuple``: first converted as a whole through
      ``_parse_value``; if that raises ``TypeError`` each element is
      converted separately and a ``list`` is returned.
    * anything else: passed through ``np.asanyarray``; boolean and numeric
      arrays become tensors on ``device``.

    :param x: the object to convert.
    :param dtype: target torch dtype, or ``None`` to keep the dtype that
        results from the conversion.
    :param device: device the resulting tensors are moved to.

    :return: the converted object.

    :raises TypeError: if ``x`` ends up as an array whose dtype is neither
        boolean nor numeric.
    """
    if isinstance(x, torch.Tensor):
        if dtype is not None:
            x = x.type(dtype)
        x = x.to(device)
    elif isinstance(x, dict):
        # values are replaced in place; the key set is left untouched
        for k, v in x.items():
            x[k] = to_torch(v, dtype, device)
    elif isinstance(x, Batch):
        x.to_torch(dtype, device)
    elif isinstance(x, (np.number, np.bool_, Number)):
        x = to_torch(np.asanyarray(x), dtype, device)
    elif isinstance(x, (list, tuple)):
        try:
            x = to_torch(_parse_value(x), dtype, device)
        except TypeError:
            # not convertible as a whole: convert element by element
            x = [to_torch(e, dtype, device) for e in x]
    else:  # fallback
        x = np.asanyarray(x)
        if issubclass(x.dtype.type, (np.bool_, np.number)):
            x = torch.from_numpy(x).to(device)
            if dtype is not None:
                x = x.type(dtype)
        else:
            raise TypeError(f"object {x} cannot be converted to torch.")
    # Without this return every call (including the recursive ones above)
    # would evaluate to None and discard the converted value.
    return x
github thu-ml / tianshou / tianshou / data / batch.py View on Github external
>>> a = Batch(a=np.zeros([4, 4]), common=Batch(c=np.zeros([4, 5])))
            >>> b = Batch(b=np.zeros([4, 6]), common=Batch(c=np.zeros([4, 5])))
            >>> c = Batch.stack([a, b])
            >>> c.a.shape
            (2, 4, 4)
            >>> c.b.shape
            (2, 4, 6)
            >>> c.common.c.shape
            (2, 4, 5)

        .. note::

            If there are keys that are not shared across all batches, ``stack``
            with ``axis != 0`` is undefined, and will cause an exception.
        """
        batch = Batch()
        batch.stack_(batches, axis)
        return batch
github thu-ml / tianshou / tianshou / data / batch.py View on Github external
pass
    else:
        if not isinstance(v, np.ndarray) and isinstance(v, Collection) and \
                len(v) > 0 and all(isinstance(e, torch.Tensor) for e in v):
            try:
                return torch.stack(v)
            except RuntimeError as e:
                raise TypeError("Batch does not support non-stackable iterable"
                                " of torch.Tensor as unique value yet.") from e
        try:
            v_ = _to_array_with_correct_type(v)
        except ValueError as e:
            raise TypeError("Batch does not support heterogeneous list/tuple"
                            " of tensors as unique value yet.") from e
        if _is_batch_set(v):
            v = Batch(v)  # list of dict / Batch
        else:
            # None, scalar, normal data list (main case)
            # or an actual list of objects
            v = v_
    return v
github thu-ml / tianshou / tianshou / data / batch.py View on Github external
shape = (size, *inst.shape) if stack else (size, *inst.shape[1:])
    if isinstance(inst, np.ndarray):
        if issubclass(inst.dtype.type, (np.bool_, np.number)):
            target_type = inst.dtype.type
        else:
            target_type = np.object
        return np.full(shape,
                       fill_value=None if target_type == np.object else 0,
                       dtype=target_type)
    elif isinstance(inst, torch.Tensor):
        return torch.full(shape,
                          fill_value=0,
                          device=inst.device,
                          dtype=inst.dtype)
    elif isinstance(inst, (dict, Batch)):
        zero_batch = Batch()
        for key, val in inst.items():
            zero_batch.__dict__[key] = _create_value(val, size, stack=stack)
        return zero_batch
    elif is_scalar:
        return _create_value(np.asarray(inst), size, stack=stack)
    else:  # fall back to np.object
        return np.array([None for _ in range(size)])
github thu-ml / tianshou / tianshou / data / buffer.py View on Github external
def __init__(self, size: int, stack_num: Optional[int] = 0,
                 ignore_obs_next: bool = False,
                 sample_avail: bool = False, **kwargs) -> None:
        """Initialise the buffer's bookkeeping attributes and call ``reset()``.

        :param size: stored as ``_maxsize``.
        :param stack_num: stored as ``_stack``; the value ``1`` is rejected
            by an assertion. Defaults to ``0``.
        :param ignore_obs_next: when true, ``_save_s_`` is set to ``False``.
        :param sample_avail: ``_avail`` is only enabled when this is truthy
            and ``stack_num`` is greater than 1.
        :param kwargs: accepted but not used inside this method.
        """
        super().__init__()
        self._maxsize = size
        self._stack = stack_num
        # Only equality with 1 is checked here; 0 (the default) and any
        # other value pass. The check disappears when Python runs with -O.
        assert stack_num != 1, 'stack_num should greater than 1'
        # `and` short-circuits: the comparison is only evaluated when
        # sample_avail is truthy.
        # NOTE(review): the annotation allows stack_num=None, but None > 1
        # would raise TypeError when sample_avail is truthy -- confirm
        # whether None is meant to be accepted.
        self._avail = sample_avail and stack_num > 1
        self._avail_index = []
        self._save_s_ = not ignore_obs_next
        self._index = 0
        self._size = 0
        self._meta = Batch()
        # reset() is defined outside this view; it is called last, after all
        # attributes above have been assigned.
        self.reset()
github thu-ml / tianshou / tianshou / data / buffer.py View on Github external
indice += 1 - self.done[indice].astype(np.int)
            indice[indice == self._size] = 0
            key = 'obs'
        val = self._meta.__dict__[key]
        try:
            if stack_num > 0:
                stack = []
                for _ in range(stack_num):
                    stack = [val[indice]] + stack
                    pre_indice = np.asarray(indice - 1)
                    pre_indice[pre_indice == -1] = self._size - 1
                    indice = np.asarray(
                        pre_indice + self.done[pre_indice].astype(np.int))
                    indice[indice == self._size] = 0
                if isinstance(val, Batch):
                    stack = Batch.stack(stack, axis=indice.ndim)
                else:
                    stack = np.stack(stack, axis=indice.ndim)
            else:
                stack = val[indice]
        except IndexError as e:
            stack = Batch()
            if not isinstance(val, Batch) or len(val.__dict__) > 0:
                raise e
        self.done[last_index] = last_done
        return stack
github thu-ml / tianshou / tianshou / data / buffer.py View on Github external
stack = []
                for _ in range(stack_num):
                    stack = [val[indice]] + stack
                    pre_indice = np.asarray(indice - 1)
                    pre_indice[pre_indice == -1] = self._size - 1
                    indice = np.asarray(
                        pre_indice + self.done[pre_indice].astype(np.int))
                    indice[indice == self._size] = 0
                if isinstance(val, Batch):
                    stack = Batch.stack(stack, axis=indice.ndim)
                else:
                    stack = np.stack(stack, axis=indice.ndim)
            else:
                stack = val[indice]
        except IndexError as e:
            stack = Batch()
            if not isinstance(val, Batch) or len(val.__dict__) > 0:
                raise e
        self.done[last_index] = last_done
        return stack