How to use the dataclasses.astuple function

To help you get started, we’ve selected a few dataclasses.astuple examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

GitHub: crowsonkb / style_transfer / style_transfer.py — View on GitHub (external link)
features[layer] = np.zeros(shape, dtype=np.float32)
        for y in range(ntiles[0]):
            for x in range(ntiles[1]):
                xy = np.array([y, x])
                start = xy * tile_size
                end = start + tile_size
                if y == ntiles[0] - 1:
                    end[0] = img_size[0]
                if x == ntiles[1] - 1:
                    end[1] = img_size[1]
                tile = self.img[:, start[0]:end[0], start[1]:end[1]]
                pool.ensure_healthy()
                pool.request(FeatureMapRequest(start, SharedNDArray.copy(tile), layers))
        pool.reset_next_worker()
        for _ in range(np.prod(ntiles)):
            start, feats_tile = astuple(pool.resp_q.get())
            for layer, feat in feats_tile.items():
                scale, _ = self.layer_info(layer)
                start_f = start // scale
                end_f = start_f + np.array(feat.array.shape[-2:])
                features[layer][:, start_f[0]:end_f[0], start_f[1]:end_f[1]] = feat.array
                feat.unlink()

        return features
GitHub: mscarey / AuthoritySpoke / authorityspoke / entities.py — View on GitHub (external link)
def means(self, other):
        """
        Report whether ``other`` has the same meaning as ``self``.

        Objects of different classes never match. Two ``generic``
        :class:`Entity` objects of the same class match without any field
        comparison; in every other case the objects match only when their
        dataclass fields, flattened by :func:`dataclasses.astuple`, are equal.
        """
        mismatched_class = self.__class__ != other.__class__
        if mismatched_class:
            return False

        if self.generic and other.generic:
            # Same class and both generic: fields are not consulted at all.
            return True

        own_values = astuple(self)
        other_values = astuple(other)
        return own_values == other_values
GitHub: espnet / espnet / espnet2 / utils / device_funcs.py — View on GitHub (external link)
def to_device(data, device=None, dtype=None, non_blocking=False, copy=False):
    """Change the device of object recursively"""
    if isinstance(data, dict):
        return {
            k: to_device(v, device, dtype, non_blocking, copy)
            for k, v in data.items()
        }
    elif dataclasses.is_dataclass(data) and not isinstance(data, type):
        return type(data)(
            *[
                to_device(v, device, dtype, non_blocking, copy)
                for v in dataclasses.astuple(data)
            ]
        )
    # maybe namedtuple. I don't know the correct way to judge namedtuple.
    elif isinstance(data, tuple) and type(data) is not tuple:
        return type(data)(
            *[to_device(o, device, dtype, non_blocking, copy) for o in data]
        )
    elif isinstance(data, (list, tuple)):
        return type(data)(
            to_device(v, device, dtype, non_blocking, copy) for v in data
        )
    elif isinstance(data, np.ndarray):
        return to_device(
            torch.from_numpy(data), device, dtype, non_blocking, copy
        )
    elif isinstance(data, torch.Tensor):
GitHub: byceps / byceps / byceps / services / party / service.py — View on GitHub (external link)
def _db_entity_to_party_with_brand(party_entity: DbParty) -> PartyWithBrand:
    """Build a :class:`PartyWithBrand` from a party database entity.

    The party part is converted with ``_db_entity_to_party`` and the brand
    with ``brand_service._db_entity_to_brand``. The party's field values are
    then passed positionally (in field declaration order), followed by
    ``brand=`` as a keyword argument.

    NOTE(review): positional passing assumes ``PartyWithBrand`` declares the
    same leading fields as ``Party`` in the same order — confirm against the
    model definitions.
    """
    party = _db_entity_to_party(party_entity)
    brand = brand_service._db_entity_to_brand(party_entity.brand)

    # Take only the top-level field values, in declaration order.
    # ``dataclasses.astuple`` would recurse into nested dataclasses, lists,
    # tuples and dicts (turning nested dataclasses into plain tuples) and
    # deep-copy every other value. Neither is wanted when re-packing the same
    # values into another dataclass.
    party_values = tuple(
        getattr(party, party_field.name)
        for party_field in dataclasses.fields(party)
    )

    return PartyWithBrand(*party_values, brand=brand)
GitHub: zalando-incubator / transformer / update-version.py — View on GitHub (external link)
def increment_version(v: Version, incr: Increment) -> Version:
    """Return a new ``Version`` with the component selected by ``incr`` bumped.

    ``incr.value`` is used as a positional index into ``v``'s dataclass
    fields. That component is incremented by one, and every later component
    up to the highest ``Increment`` value is reset to 0.

    :param v: the version to bump; it is not modified (a list copy of its
        field values is edited instead).
    :param incr: which component to increment.
    :return: a new ``Version`` built positionally from the updated values.
    """
    values = list(dataclasses.astuple(v))
    values[incr.value] += 1
    # Take the max over member *values*, not over members. ``max(Increment)``
    # requires the enum members themselves to be orderable (true for IntEnum)
    # and raises TypeError for a plain ``enum.Enum``. This form gives the same
    # result in both cases.
    last_index = max(member.value for member in Increment)
    for i in range(incr.value + 1, last_index + 1):
        values[i] = 0
    return Version(*values)
GitHub: mscarey / AuthoritySpoke / authorityspoke / entities.py — View on GitHub (external link)
def means(self, other):
        """
        Check whether ``other`` is equivalent in meaning to ``self``.

        A different class always means "not equivalent". Within one class,
        a pair of ``generic`` :class:`Entity` objects is equivalent outright;
        otherwise equivalence requires all dataclass attribute values (as
        flattened by :func:`dataclasses.astuple`) to be equal.
        """
        if self.__class__ != other.__class__:
            return False
        if not (self.generic and other.generic):
            # At least one side is non-generic, so compare every attribute.
            return astuple(self) == astuple(other)
        return True
GitHub: spacetx / starfish / starfish / core / experiment / builder / __init__.py — View on GitHub (external link)
def reducer_to_sets(
            accumulated: Sequence[MutableSet[int]], update: TileIdentifier,
    ) -> Sequence[MutableSet[int]]:
        """Fold one tile identifier into per-component sets of seen values.

        Each field of ``update`` (flattened by ``astuple``) is added to the set
        at the same position in ``accumulated``. Per the original note, the
        order is FOV, round, ch, zplane. Pairing stops at the shorter of the
        two sequences (``zip`` semantics).

        :return: a new list holding the same set objects that were passed in.
            They are mutated in place, not copied.
        """
        components = astuple(update)
        collected: MutableSequence[MutableSet[int]] = []
        for seen_values, component in zip(accumulated, components):
            seen_values.add(component)
            collected.append(seen_values)
        return collected
    initial_value: Sequence[MutableSet[int]] = tuple(set() for _ in range(4))
GitHub: spacetx / starfish / starfish / core / experiment / builder / __init__.py — View on GitHub (external link)
Coordinates.X,
                Coordinates.Y,
                Coordinates.Z,
                Axes.ZPLANE,
                Axes.ROUND,
                Axes.CH,
                Axes.X,
                Axes.Y,
            ],
            {Axes.ROUND: len(rounds), Axes.CH: len(chs), Axes.ZPLANE: len(zplanes)},
            default_shape,
            ImageFormat.TIFF,
        )

        for tile_identifier in tile_identifiers:
            current_fov, current_round, current_ch, current_zplane = astuple(tile_identifier)
            # filter out the fovs that are not the one we are currently processing
            if expected_fov != current_fov:
                continue
            image = image_fetcher.get_tile(
                current_fov,
                current_round,
                current_ch,
                current_zplane
            )
            for axis in (Axes.X, Axes.Y):
                if image.shape[axis] < max(len(rounds), len(chs), len(zplanes)):
                    warnings.warn(
                        f"{axis} axis appears to be smaller than rounds/chs/zplanes, which is "
                        "unusual",
                        DataFormatWarning
                    )
GitHub: jseppanen / azalea / azalea / search_tree.py — View on GitHub (external link)
def step(self, move_id: int) -> None:
        """Move one step forward in game and tree.
        :param move_id: Ordinal move id among legal moves in current node
        """
        state = self._game.state
        assert move_id < len(state.legal_moves), \
            f"move id {move_id} out of range"
        assert not state.result, 'game ended in select'
        move = state.legal_moves[move_id]
        self._game.step(move)
        self._node = self._node.child(move_id)
        if DEBUG:
            print(f'step {move}: node {self._node.id} '
                  f'nch {self._node.num_children} '