How to use the dowhy.do_samplers.get_class_object function in dowhy

To help you get started, we’ve selected a few examples of dowhy.do_samplers.get_class_object, based on popular ways the function is used in public projects.

From microsoft/dowhy on GitHub, dowhy/api/causal_data_frame.py:
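# Excerpt from the body of a method on the pandas DataFrame accessor defined in
# dowhy/api/causal_data_frame.py. x, outcome, method, stateful, dot_graph,
# common_causes, estimand_type, proceed_when_unidentifiable, params,
# variable_types, num_cores and keep_original_treatment are set by the
# enclosing method (not shown); parse_state, CausalModel and do_samplers
# come from the module's imports.
outcome = parse_state(outcome)

# Start fresh unless the caller wants to keep state AND uses the same method.
if not stateful or method != self._method:
    self.reset()

# Lazily build the causal model; the treatments are the keys of x.
if not self._causal_model:
    self._causal_model = CausalModel(self._obj,
                                     list(x.keys()),
                                     outcome,
                                     graph=dot_graph,
                                     common_causes=common_causes,
                                     instruments=None,
                                     estimand_type=estimand_type,
                                     proceed_when_unidentifiable=proceed_when_unidentifiable)

# Resolve the sampler class by name (e.g. method="weighting" looks up
# "weighting_sampler") and instantiate it once.
if not self._sampler:
    self._method = method
    do_sampler_class = do_samplers.get_class_object(method + "_sampler")
    self._sampler = do_sampler_class(self._obj,
                                     params=params,
                                     variable_types=variable_types,
                                     num_cores=num_cores,
                                     causal_model=self._causal_model,
                                     keep_original_treatment=keep_original_treatment)

# Sample from the interventional distribution for the values given in x.
result = self._sampler.do_sample(x)

# Clear the cached state (via reset) unless stateful=True.
if not stateful:
    self.reset()
return result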
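The excerpt combines two ideas: get_class_object turns a method name such as "weighting" (suffixed with "_sampler") into a sampler class, and the accessor caches that sampler unless stateful is False or the method changes. The following is a minimal, dependency-free sketch of both ideas, not dowhy's implementation. It assumes a naming convention in which "weighting_sampler" maps to a class named WeightingSampler. DoSampler, WeightingSampler, SAMPLERS, class_name_for, ToyCausalAccessor and this simplified get_class_object are hypothetical stand-ins, and the toy sampler only filters rows; it performs no causal adjustment.

```python
import string
import types


class DoSampler:
    """Hypothetical base class standing in for a library's sampler base."""

    def __init__(self, data, **kwargs):
        self.data = data
        self.options = kwargs

    def do_sample(self, x):
        raise NotImplementedError


class WeightingSampler(DoSampler):
    """Toy placeholder: keeps rows matching x. It does NOT adjust for confounding."""

    def do_sample(self, x):
        return [row for row in self.data
                if all(row[name] == value for name, value in x.items())]


# Hypothetical stand-in for a package of sampler modules; a real factory
# could import a module with importlib.import_module and use getattr on it.
SAMPLERS = types.SimpleNamespace(WeightingSampler=WeightingSampler)


def class_name_for(method_name):
    """Map a snake_case name such as 'weighting_sampler' to 'WeightingSampler'."""
    return string.capwords(method_name, "_").replace("_", "")


def get_class_object(method_name, namespace=SAMPLERS):
    """Resolve a sampler class from its snake_case name, or raise ImportError."""
    cls = getattr(namespace, class_name_for(method_name), None)
    if not (isinstance(cls, type) and issubclass(cls, DoSampler)):
        raise ImportError(f"{method_name} is not an existing do sampler.")
    return cls


class ToyCausalAccessor:
    """Caches one sampler, mirroring the stateful/reset logic of the excerpt."""

    def __init__(self, data):
        self._data = data
        self._method = None
        self._sampler = None

    def reset(self):
        self._method = None
        self._sampler = None

    def do(self, x, method="weighting", stateful=False):
        # Rebuild unless state is kept AND the method is unchanged.
        if not stateful or method != self._method:
            self.reset()
        if self._sampler is None:
            self._method = method
            sampler_class = get_class_object(method + "_sampler")
            self._sampler = sampler_class(self._data)
        result = self._sampler.do_sample(x)
        # Without stateful=True, nothing is cached between calls.
        if not stateful:
            self.reset()
        return result


if __name__ == "__main__":
    rows = [{"t": 0, "y": 1.0}, {"t": 1, "y": 3.0}, {"t": 1, "y": 5.0}]
    accessor = ToyCausalAccessor(rows)
    print(accessor.do({"t": 1}))  # [{'t': 1, 'y': 3.0}, {'t': 1, 'y': 5.0}]
```

Looking classes up by name keeps the caller decoupled from the set of available samplers: a new sampler only needs a class that follows the naming convention, and an unknown name fails early with an ImportError.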