How to use dataclasses - 10 common examples

To help you get started, we’ve selected a few `dataclasses` examples, based on popular ways the module is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github StanfordAHA / garnet / global_buffer / global_buffer_magma_helper.py View on Github external
import dataclasses
import magma as m


@dataclasses.dataclass(eq=True, frozen=True)
class GlobalBufferParams:
    """Immutable bundle of global-buffer (GLB) configuration parameters.

    ``frozen=True`` makes instances read-only, and together with ``eq=True``
    dataclasses generates ``__hash__``, so instances can be used as dict keys.

    NOTE(review): the derived defaults below (TILE_SEL_ADDR_WIDTH,
    CGRA_PER_GLB, BANK_SEL_ADDR_WIDTH) are evaluated once, while the class
    body runs, from the *default* values of the fields they reference.
    Passing e.g. ``NUM_GLB_TILES=8`` to the constructor does NOT recompute
    them; callers overriding a base field must pass the matching derived
    fields explicitly.
    """
    # Tile parameters
    NUM_GLB_TILES: int = 16
    # Address bits needed to select a GLB tile; clog2 is presumably
    # ceil(log2(n)) — confirm against magma.bitutils.
    TILE_SEL_ADDR_WIDTH: int = m.bitutils.clog2(NUM_GLB_TILES)

    # CGRA Tiles
    NUM_CGRA_TILES: int = 32

    # CGRA tiles per GLB tile (32 // 16 == 2 with the defaults)
    CGRA_PER_GLB: int = NUM_CGRA_TILES // NUM_GLB_TILES # 2

    # Bank parameters
    BANKS_PER_TILE: int = 2
    # Address bits needed to select a bank within a tile (same clog2 helper).
    BANK_SEL_ADDR_WIDTH: int = m.bitutils.clog2(BANKS_PER_TILE)
    # Presumably widths in bits — TODO confirm units against the RTL.
    BANK_DATA_WIDTH: int = 64
    BANK_ADDR_WIDTH: int = 17
github crowsonkb / style_transfer / style_transfer.py View on Github external
features[layer] = np.zeros(shape, dtype=np.float32)
        for y in range(ntiles[0]):
            for x in range(ntiles[1]):
                xy = np.array([y, x])
                start = xy * tile_size
                end = start + tile_size
                if y == ntiles[0] - 1:
                    end[0] = img_size[0]
                if x == ntiles[1] - 1:
                    end[1] = img_size[1]
                tile = self.img[:, start[0]:end[0], start[1]:end[1]]
                pool.ensure_healthy()
                pool.request(FeatureMapRequest(start, SharedNDArray.copy(tile), layers))
        pool.reset_next_worker()
        for _ in range(np.prod(ntiles)):
            start, feats_tile = astuple(pool.resp_q.get())
            for layer, feat in feats_tile.items():
                scale, _ = self.layer_info(layer)
                start_f = start // scale
                end_f = start_f + np.array(feat.array.shape[-2:])
                features[layer][:, start_f[0]:end_f[0], start_f[1]:end_f[1]] = feat.array
                feat.unlink()

        return features
github samuelcolvin / pydantic / tests / test_dataclasses.py View on Github external
def test_initvars_post_init():
    """InitVar fields are passed to __post_init__ but never stored on the instance."""

    @pydantic.dataclasses.dataclass
    class PathDataPostInit:
        path: Path
        base_path: dataclasses.InitVar[Optional[Path]] = None

        def __post_init__(self, base_path):
            # Only prefix the path when a base path was actually supplied.
            if base_path is None:
                return
            self.path = base_path / self.path

    instance = PathDataPostInit('world')
    stored = instance.__dict__
    assert 'path' in stored
    assert 'base_path' not in stored
    assert instance.path == Path('world')

    # The expected message shows both operands are still plain str when
    # __post_init__ runs, so joining them with "/" fails.
    expected_message = "unsupported operand type(s) for /: 'str' and 'str'"
    with pytest.raises(TypeError) as raised:
        PathDataPostInit('world', base_path='/hello')
    assert str(raised.value) == expected_message
github pantsbuild / pants / tests / python / pants_test / util / test_meta.py View on Github external
def test_no_init(self) -> None:
  """Expect attribute assignment to fail on a decorated class that has no __init__."""

  @frozen_after_init
  class Test:
    pass

  instance = Test()
  with self.assertRaises(FrozenInstanceError):
    instance.x = 1  # type: ignore[attr-defined]
github pantsbuild / pants / tests / python / pants_test / util / test_meta.py View on Github external
def test_add_new_field_after_init(self) -> None:
  """Expect adding a brand-new attribute after __init__ has finished to fail."""

  @frozen_after_init
  class Test:

    def __init__(self, x: int) -> None:
      self.x = x

  instance = Test(x=0)
  with self.assertRaises(FrozenInstanceError):
    instance.y = "abc"  # type: ignore[attr-defined]
github omry / omegaconf / tests / structured_conf / data / dataclasses.py View on Github external
@dataclass
class ErrorDictIntKey:
    # Deliberately invalid: the dict is keyed by int, but keys must be str.
    dict: Dict[int, str] = field(
        default_factory=lambda: {10: "foo", 20: "bar"},
    )


class RegularClass:
    """Plain (non-dataclass) class used below as an unsupported container element type."""


@dataclass
class ErrorDictUnsupportedValue:
    # Deliberately invalid: RegularClass is not one of the supported dict
    # value types.
    dict: Dict[str, RegularClass] = field(
        default_factory=dict,
    )


@dataclass
class ErrorListUnsupportedValue:
    # invalid list element type, not one of the supported types
    # NOTE(review): the field is named ``dict`` although it holds a List —
    # probably copied from ErrorDictUnsupportedValue. The name is left
    # unchanged because callers may access the field by this name.
    dict: List[RegularClass] = field(default_factory=list)


@dataclass
class ErrorListUnsupportedStructuredConfig:
    # Structured configs nested inside a List (or Dict) are not supported yet,
    # so a List[User] field is expected to be rejected.
    list: List[User] = field(
        default_factory=list,
    )


@dataclass
class ListExamples:
github anchore / syft / test / inline-compare / utils / package.py View on Github external
import difflib
import collections
import dataclasses
from typing import Set, FrozenSet, Tuple, Any, List

# Immutable records for the raw report data consumed by Analysis.
Metadata = collections.namedtuple("Metadata", ["version"])
Package = collections.namedtuple("Package", ["name", "type"])
Info = collections.namedtuple("Info", ["packages", "metadata"])

# Records presumably used when pairing packages across the two reports.
SimilarPackages = collections.namedtuple("SimilarPackages", ["pkg", "missed"])
ProbableMatch = collections.namedtuple("ProbableMatch", ["pkg", "ratio"])


@dataclasses.dataclass
class Analysis:
    """
    Compare package metadata between a syft report and an inline-scan report.

    Construct it with the raw data from both tools. The derived comparison
    fields are declared with ``init=False``, so they are not constructor
    arguments; they are meant to hold everything a caller needs to interpret
    the differences (how they get populated is not visible in this excerpt).
    """

    # Raw inputs: the syft report data and the inline-scan report data.
    syft_data: Info
    inline_data: Info

    # Derived results, computed from the raw inputs (excluded from __init__).
    overlapping_packages: FrozenSet[Package] = dataclasses.field(init=False)
    extra_packages: FrozenSet[Package] = dataclasses.field(init=False)
    missing_packages: FrozenSet[Package] = dataclasses.field(init=False)

    inline_metadata: Set[Tuple[Any, Any]] = dataclasses.field(init=False)
github danieljfarrell / pvtrace / tests / test_tracer.py View on Github external
"""
        root = Node(name="Root", geometry=Sphere(radius=10.0))
        a = Node(name="A", parent=root, geometry=Sphere(radius=1.0))
        a.translate((5.0, 0.0, 0.0))
        scene = Scene(root)
        tracer = PhotonTracer(scene)
        position = (-2.0, 0.0, 0.0)
        direction = (1.0, 0.0, 0.0)
        initial_ray = Ray(
            position=position, direction=direction, wavelength=555.0, is_alive=True
        )
        expected_history = [
            initial_ray,  # Starting ray
            replace(initial_ray, position=(4.0, 0.0, 0.0)),  # First intersection
            replace(initial_ray, position=(6.0, 0.0, 0.0)),  # Second intersection
            replace(initial_ray, position=(10.0, 0.0, 0.0), is_alive=False),  # Exit ray
        ]
        history = tracer.follow(initial_ray)
        for pair in zip(history, expected_history):
            assert pair[0] == pair[1]
github dfurtado / dataclass-csv / tests / test_dataclass_reader.py View on Github external
def test_reader_values(create_csv):
    """DataclassReader yields one populated dataclass instance per CSV row."""
    rows = [{'name': 'User1', 'age': 40}, {'name': 'User2', 'age': 30}]
    csv_file = create_csv(rows)

    with csv_file.open() as handle:
        items = list(DataclassReader(handle, User))

        assert items and len(items) == 2
        assert all(dataclasses.is_dataclass(entry) for entry in items)

        first, second = items[0], items[1]

        # Values must come back converted to the dataclass field types.
        assert (first.name, first.age) == ('User1', 40)
        assert (second.name, second.age) == ('User2', 30)
github oxan / djangorestframework-dataclasses / tests / test_fields.py View on Github external
def build_typed_field(self, type_hint, extra_kwargs=None):
    """Return the serializer field built for one attribute annotated with ``type_hint``.

    The hint is wrapped in a throwaway one-field dataclass, and a
    DataclassSerializer is asked for the field it would build for it.
    ``extra_kwargs`` (``None`` or any falsy value means "none") is forwarded.
    """
    field_name = 'test_field'
    wrapper = dataclasses.make_dataclass('TestDataclass', [(field_name, type_hint)])
    serializer = DataclassSerializer(dataclass=wrapper)

    declared_type = serializer.dataclass_definition.field_types[field_name]
    type_info = field_utils.get_type_info(declared_type)

    kwargs = extra_kwargs if extra_kwargs else {}
    return serializer.build_typed_field(field_name, type_info, kwargs)