How to use the dataclasses.replace function in Python

To help you get started, we’ve selected a few dataclasses.replace examples based on popular ways it is used in public projects.
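
dataclasses.replace(obj, **changes) builds a new instance of the same dataclass, copying every field from obj and overriding only the fields passed as keyword arguments; the original object is left untouched, which makes it a natural fit for frozen dataclasses. A minimal sketch (the Ray class here is illustrative, not taken from any of the projects below):

from dataclasses import dataclass, replace

@dataclass(frozen=True)
class Ray:
    position: tuple
    wavelength: float
    is_alive: bool = True

ray = Ray(position=(0.0, 0.0, 0.0), wavelength=555.0)
moved = replace(ray, position=(1.0, 0.0, 0.0))  # new Ray, other fields copied
dead = replace(moved, is_alive=False)           # changes can be chained
print(ray)   # Ray(position=(0.0, 0.0, 0.0), wavelength=555.0, is_alive=True)
print(dead)  # Ray(position=(1.0, 0.0, 0.0), wavelength=555.0, is_alive=False)

Because the new object is created by calling __init__ again, fields declared with init=False cannot be passed to replace.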


github danieljfarrell / pvtrace / tests / test_tracer.py
"""
        root = Node(name="Root", geometry=Sphere(radius=10.0))
        a = Node(name="A", parent=root, geometry=Sphere(radius=1.0))
        a.translate((5.0, 0.0, 0.0))
        scene = Scene(root)
        tracer = PhotonTracer(scene)
        position = (-2.0, 0.0, 0.0)
        direction = (1.0, 0.0, 0.0)
        initial_ray = Ray(
            position=position, direction=direction, wavelength=555.0, is_alive=True
        )
        expected_history = [
            initial_ray,  # Starting ray
            replace(initial_ray, position=(4.0, 0.0, 0.0)),  # First intersection
            replace(initial_ray, position=(6.0, 0.0, 0.0)),  # Second intersection
            replace(initial_ray, position=(10.0, 0.0, 0.0), is_alive=False),  # Exit ray
        ]
        history = tracer.follow(initial_ray)
        for pair in zip(history, expected_history):
            assert pair[0] == pair[1]
github facebookresearch / ReAgent / ml / rl / readers / data_streamer.py
def pin_memory(batch):
    """
    This is ripped off from dataloader. The only difference is that it preserves
    the type of Mapping so that the OrderedDict is maintained.
    """
    if isinstance(batch, torch.Tensor):
        return batch.pin_memory().cuda(non_blocking=True)
    elif isinstance(batch, string_classes):
        return batch
    elif dataclasses.is_dataclass(batch):
        return dataclasses.replace(
            batch,
            **{
                field.name: pin_memory(getattr(batch, field.name))
                for field in dataclasses.fields(batch)
            }
        )
    elif isinstance(batch, collections.abc.Mapping):
        # NB: preserving OrderedDict
        return type(batch)((k, pin_memory(sample)) for k, sample in batch.items())
    elif isinstance(batch, NamedTuple) or hasattr(batch, "_asdict"):
        # This is mainly for WorkerDone
        return type(batch)(
            **{name: pin_memory(value) for name, value in batch._asdict().items()}
        )
    elif isinstance(batch, collections.abc.Sequence):
        return [pin_memory(sample) for sample in batch]
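
The pattern above (reading every field with dataclasses.fields, mapping a function over the values, and rebuilding the instance with dataclasses.replace) works for any dataclass. Here is a self-contained sketch of the same idea; the Batch class and map_fields helper are illustrative, not part of ReAgent:

import dataclasses

@dataclasses.dataclass(frozen=True)
class Batch:
    features: list
    labels: list

def map_fields(obj, fn):
    # Apply fn to every field value and return a new instance of the same type.
    return dataclasses.replace(
        obj,
        **{f.name: fn(getattr(obj, f.name)) for f in dataclasses.fields(obj)}
    )

batch = Batch(features=[1.0, 2.0], labels=[0, 1])
doubled = map_fields(batch, lambda xs: [2 * x for x in xs])
print(doubled)  # Batch(features=[2.0, 4.0], labels=[0, 2])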
github danieljfarrell / pvtrace / pvtrace / material / mechanisms.py
def transform(self, ray: Ray, context: dict) -> Ray:
        _check_required_keys(set(["normal"]), context)
        normal = context["normal"]
        # check for angle > 90
        if np.dot(normal, ray.direction) < 0.0:
            normal = flip(normal)
        distance = 2*EPS_ZERO
        new_position = np.array(ray.position) + distance * np.array(normal)
        new_ray = replace(ray, position=new_position)
        return new_ray
github goldmansachs / gs-quant / gs_quant / api / fred / data.py
def query_data(self, query: FredQuery, dataset_id: str, asset_id_type: str = None) -> pd.Series:
        """
        Query data for a valid FRED series id. Raises an HTTPError if the response was an HTTP error.

        :param query: the FredQuery describing the requested data
        :param dataset_id: a FRED series id
        :return: the requested data as a pandas Series named after the series id
        """
        request = replace(query, api_key=self.api_key, series_id=dataset_id)
        response = handle_proxy(self.root_url, asdict(request))
        handled = self.__handle_response(response)
        handled.name = dataset_id
        return handled
github seandstewart / typical / typic / constraints / factory.py
def get_constraints(
    t: Type[VT], *, nullable: bool = False, name: str = None, cls: Type = None
) -> ConstraintsT:
    while should_unwrap(t):
        nullable = nullable or isoptionaltype(t)
        t = get_args(t)[0]
    if isconstrained(t):
        c: ConstraintsT = t.__constraints__  # type: ignore
        if (c.name, c.nullable) != (name, nullable):
            return dataclasses.replace(c, name=name, nullable=nullable)
        return c
    if issubclass(t, enum.Enum):
        return _from_enum_type(t, nullable=nullable, name=name)  # type: ignore
    if isnamedtuple(t) or istypeddict(t):
        handler = _from_class
    else:
        handler = _CONSTRAINT_BUILDER_HANDLERS.get_by_parent(origin(t), _from_class)  # type: ignore
    c = handler(t, nullable=nullable, name=name, cls=cls)  # type: ignore
    return c
github googleapis / gapic-generator-python / gapic / schema / metadata.py
def child(self, child_name: str, path: Tuple[int, ...]) -> 'Address':
        """Return a new child of the current Address.

        Args:
            child_name (str): The name of the child node.
                This address' name is appended to ``parent``.
            path (Tuple[int, ...]): Additional segments appended to
                ``module_path``.

        Returns:
            ~.Address: The new address object.
        """
        return dataclasses.replace(
            self,
            module_path=self.module_path + path,
            name=child_name,
            parent=self.parent + (self.name,) if self.name else self.parent,
        )
github getsentry / snuba / snuba / query / expressions.py
def transform(self, func: Callable[[Expression], Expression]) -> Expression:
        """
        Transforms the subtree by starting from the children and then applying
        the transformation function to the root.
        This order is chosen to make the semantics of transform more meaningful:
        the transform operation is performed on the children first (think of
        the parameters of a function call) and then on the node itself.

        The consequence of this is that, if the transformation function replaces
        the root with something that has different children, we trust the
        transformation function and do not run that same function over the
        new children.
        """
        transformed = replace(
            self,
            parameters=tuple(map(lambda child: child.transform(func), self.parameters)),
        )
        return func(transformed)
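
Because replace always builds a brand-new object, it pairs naturally with immutable (frozen) expression trees like the one above: each node is rebuilt with its transformed children instead of being mutated in place. A minimal, self-contained sketch of the same bottom-up traversal; the Call class is illustrative and not Snuba's actual Expression API:

from dataclasses import dataclass, replace
from typing import Callable, Tuple

@dataclass(frozen=True)
class Call:
    name: str
    parameters: Tuple["Call", ...] = ()

    def transform(self, func: Callable[["Call"], "Call"]) -> "Call":
        # Rebuild this node with transformed children, then let func rewrite the result.
        transformed = replace(
            self, parameters=tuple(child.transform(func) for child in self.parameters)
        )
        return func(transformed)

expr = Call("outer", (Call("inner"),))
renamed = expr.transform(lambda node: replace(node, name=node.name.upper()))
print(renamed)  # Call(name='OUTER', parameters=(Call(name='INNER', parameters=()),))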
github Datatamer / tamr-client / tamr_client / attribute / _attribute.py
def _create(
    session: Session,
    dataset: Dataset,
    *,
    name: str,
    is_nullable: bool,
    type: AttributeType = attribute_type.DEFAULT,
    description: Optional[str] = None,
) -> Attribute:
    """Same as `tc.attribute.create`, but does not check for reserved attribute
    names.
    """
    attrs_url = replace(dataset.url, path=dataset.url.path + "/attributes")
    url = replace(attrs_url, path=attrs_url.path + f"/{name}")

    body = {
        "name": name,
        "type": attribute_type.to_json(type),
        "isNullable": is_nullable,
    }
    if description is not None:
        body["description"] = description

    r = session.post(str(attrs_url), json=body)
    if r.status_code == 409:
        raise AlreadyExists(str(url))
    data = response.successful(r).json()

    return _from_json(url, data)
github rahzaazhar / PAN-PSEnet / Recognition / Config.py
#training baseline Bangla model CNN+LSTM
C4 = Config(experiment_name = 'Bangla_Baseline_Test', exp_dir = 'Experiments', train_data = '/content/drive/My Drive/data/training', valid_data = '/content/drive/My Drive/data/validation', langs = ['arab'], pli = [1000], mode = ['train'])
C4_val = replace(C4,mode='val')

C2 = replace(C1, experiment_name = 'TestDataclass1', langs = ['hin','ban','arab'], pli = [1000,1000,1000], mode = ['train','train','train'])
C3_test = replace(C2, experiment_name = 'ABH(CNN)', pli = [2,2,2], share = 'CNN', total_data_usage_ratio = '0.05')
C3_tst_val = replace(C3_test, mode = ['val','val','val'])
'''

C_subnet_ban = Config(experiment_name = 'Gen_Ban_Subnet1', exp_dir='Experiments', train_data = 'training', valid_data = 'validation', langs = ['ban'], pli = [1000], mode=['train'], num_iter = 100, task_id = [1])
C_subnet_ban_val = replace(C_subnet_ban,mode=['val'])
P_subnet_ban = PruneConfig()

C_subnet_ban = Config(experiment_name = 'Gen_Ban_Subnet', exp_dir='Experiments', train_data = '/content/drive/My Drive/data/training', valid_data = '/content/drive/My Drive/data/validation', langs = ['ban'], pli = [1000], mode=['train'], num_iter = 6000, task_id = [1])
C_subnet_ban_val = replace(C_subnet_ban,mode=['val'])

C_subnet_arab = Config(experiment_name = 'Gen_Arab_Subnet', exp_dir='Experiments', train_data = '/content/drive/My Drive/data/training', valid_data = '/content/drive/My Drive/data/validation', langs = ['arab'], pli = [1000], mode=['train'], num_iter = 6000, task_id = [0])
C_subnet_arab_val = replace(C_subnet_arab,mode=['val'])

C_subnet_hin = Config(experiment_name = 'Gen_Hin_Subnet', exp_dir='Experiments', train_data = '/content/drive/My Drive/data/training', valid_data = '/content/drive/My Drive/data/validation', langs = ['hin'], pli = [1000], mode=['train'], num_iter = 6000, task_id = [2])
P_subnet_ban = PruneConfig()
github codeforpdx / recordexpungPDX / src / backend / expungeservice / record_creator.py
def _analyze_ambiguous_record(ambiguous_record: AmbiguousRecord, charge_ids_with_question: List[str]):
        charge_id_to_time_eligibilities = []
        ambiguous_record_with_errors = []
        for record in ambiguous_record:
            charge_id_to_time_eligibility = Expunger.run(record)
            charge_id_to_time_eligibilities.append(charge_id_to_time_eligibility)
            ambiguous_record_with_errors.append(record)
        record = RecordMerger.merge(
            ambiguous_record_with_errors, charge_id_to_time_eligibilities, charge_ids_with_question
        )
        sorted_record = RecordCreator.sort_record(record)
        return replace(sorted_record, errors=tuple(ErrorChecker.check(sorted_record)))