How to use the dowhy.utils.api.parse_state function in dowhy

To help you get started, we’ve selected a few examples from dowhy, based on popular ways `parse_state` is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes—no build needed—and fix issues immediately.

GitHub: microsoft / dowhy / dowhy / do_sampler.py (View on GitHub)
:param data: pandas.DataFrame containing the data
        :param identified_estimand: dowhy.causal_identifier.IdentifiedEstimand: and estimand using a backdoor method
        for effect identification.
        :param treatments: list or str:  names of the treatment variables
        :param outcomes: list or str: names of the outcome variables
        :param variable_types: dict: A dictionary containing the variable's names and types. 'c' for continuous, 'o'
        for ordered, 'd' for discrete, and 'u' for unordered discrete.
        :param keep_original_treatment: bool: Whether to use `make_treatment_effective`, or to keep the original
        treatment assignments.
        :param params: (optional) additional method parameters

        """
        self._data = data.copy()
        self._target_estimand = identified_estimand
        self._treatment_names = parse_state(treatments)
        self._outcome_names = parse_state(outcomes)
        self._estimate = None
        self._variable_types = variable_types
        self.num_cores = num_cores
        self.point_sampler = True
        self.sampler = None
        self.keep_original_treatment = keep_original_treatment

        if params is not None:
            for key, value in params.items():
                setattr(self, key, value)

        self._df = self._data.copy()

        if not self._variable_types:
            self._infer_variable_types()
GitHub: microsoft / dowhy / dowhy / causal_identifier.py (View on GitHub)
def __init__(self, treatment_variable, outcome_variable,
             estimand_type=None, estimands=None,
             backdoor_variables=None, instrumental_variables=None):
    """Record the components of an identified estimand.

    The four variable-name arguments are passed through ``parse_state``
    (presumably normalizing a single name / None into a list -- see
    ``dowhy.utils.api.parse_state``) before being stored. ``estimand_type``
    and ``estimands`` are stored unchanged, and ``identifier_method`` is
    initialised to ``None``.

    :param treatment_variable: name(s) of the treatment variable(s)
    :param outcome_variable: name(s) of the outcome variable(s)
    :param estimand_type: stored as-is on ``self.estimand_type``
    :param estimands: stored as-is on ``self.estimands``
    :param backdoor_variables: name(s) of the backdoor variables
    :param instrumental_variables: name(s) of the instrumental variables
    """
    # Normalize every name argument, in the same order as the attributes below.
    raw_names = (treatment_variable, outcome_variable,
                 backdoor_variables, instrumental_variables)
    (self.treatment_variable,
     self.outcome_variable,
     self.backdoor_variables,
     self.instrumental_variables) = [parse_state(raw) for raw in raw_names]

    # Pass-through fields; no identifier method is recorded at construction.
    self.estimand_type = estimand_type
    self.estimands = estimands
    self.identifier_method = None
GitHub: microsoft / dowhy / dowhy / causal_model.py (View on GitHub)
:param graph: path to DOT file containing a DAG or a string containing
        a DAG specification in DOT format
        :param common_causes: names of common causes of treatment and _outcome
        :param instruments: names of instrumental variables for the effect of
        treatment on outcome
        :param effect_modifiers: names of variables that can modify the treatment effect (useful for heterogeneous treatment effect estimation)
        :param estimand_type: the type of estimand requested (currently only "nonparametric-ate" is supported). In the future, may support other specific parametric forms of identification.
        :proceed_when_unidentifiable: does the identification proceed by ignoring potential unobserved confounders. Binary flag.
        :missing_nodes_as_confounders: Binary flag indicating whether variables in the dataframe that are not included in the causal graph, should be  automatically included as confounder nodes.

        :returns: an instance of CausalModel class

        """
        self._data = data
        self._treatment = parse_state(treatment)
        self._outcome = parse_state(outcome)
        self._estimand_type = estimand_type
        self._proceed_when_unidentifiable = proceed_when_unidentifiable
        self._missing_nodes_as_confounders = missing_nodes_as_confounders
        if 'logging_level' in kwargs:
            logging.basicConfig(level=kwargs['logging_level'])
        else:
            logging.basicConfig(level=logging.INFO)

        # TODO: move the logging level argument to a json file. Tue 20 Feb 2018 06:56:27 PM DST
        self.logger = logging.getLogger(__name__)

        if graph is None:
            self.logger.warning("Causal Graph not provided. DoWhy will construct a graph based on data inputs.")
            self._common_causes = parse_state(common_causes)
            self._instruments = parse_state(instruments)
            self._effect_modifiers = parse_state(effect_modifiers)
GitHub: microsoft / dowhy / dowhy / api / causal_data_frame.py (View on GitHub)
:param num_cores: int: if the inference method only supports sampling a point at a time, this will parallelize
        sampling.
        :param variable_types: dict: The dictionary containing the variable types. Must contain the union of the causal
        state, control variables, and the outcome.
        :param outcome: str: The outcome variable.
        :param params: dict: extra parameters to set as attributes on the sampler object
        :param dot_graph: str: A string specifying the causal graph.
        :param common_causes: list: A list of strings containing the variable names to control for.
        :param estimand_type: str: 'nonparametric-ate' is the only one currently supported. Others may be added later, to allow for specific, parametric estimands.
        :param proceed_when_unidentifiable: bool: A flag to over-ride user prompts to proceed when effects aren't
        identifiable with the assumptions provided.
        :param stateful: bool: Whether to retain state. By default, the do operation is stateless.
        :return: pandas.DataFrame: A DataFrame containing the sampled outcome
        """
        x, keep_original_treatment = self.parse_x(x)
        outcome = parse_state(outcome)
        if not stateful or method != self._method:
            self.reset()
        if not self._causal_model:
            self._causal_model = CausalModel(self._obj,
                                             [xi for xi in x.keys()],
                                             outcome,
                                             graph=dot_graph,
                                             common_causes=common_causes,
                                             instruments=None,
                                             estimand_type=estimand_type,
                                             proceed_when_unidentifiable=proceed_when_unidentifiable)
        #self._identified_estimand = self._causal_model.identify_effect()
        if not self._sampler:
            self._method = method
            do_sampler_class = do_samplers.get_class_object(method + "_sampler")
            self._sampler = do_sampler_class(self._obj,
GitHub: microsoft / dowhy / dowhy / causal_model.py (View on GitHub)
self._estimand_type = estimand_type
        self._proceed_when_unidentifiable = proceed_when_unidentifiable
        self._missing_nodes_as_confounders = missing_nodes_as_confounders
        if 'logging_level' in kwargs:
            logging.basicConfig(level=kwargs['logging_level'])
        else:
            logging.basicConfig(level=logging.INFO)

        # TODO: move the logging level argument to a json file. Tue 20 Feb 2018 06:56:27 PM DST
        self.logger = logging.getLogger(__name__)

        if graph is None:
            self.logger.warning("Causal Graph not provided. DoWhy will construct a graph based on data inputs.")
            self._common_causes = parse_state(common_causes)
            self._instruments = parse_state(instruments)
            self._effect_modifiers = parse_state(effect_modifiers)
            if common_causes is not None and instruments is not None:
                self._graph = CausalGraph(
                    self._treatment,
                    self._outcome,
                    common_cause_names=self._common_causes,
                    instrument_names=self._instruments,
                    effect_modifier_names = self._effect_modifiers,
                    observed_node_names=self._data.columns.tolist()
                )
            elif common_causes is not None:
                self._graph = CausalGraph(
                    self._treatment,
                    self._outcome,
                    common_cause_names=self._common_causes,
                    effect_modifier_names = self._effect_modifiers,
                    observed_node_names=self._data.columns.tolist()
GitHub: microsoft / dowhy / dowhy / causal_graph.py (View on GitHub)
def __init__(self,
                 treatment_name, outcome_name,
                 graph=None,
                 common_cause_names=None,
                 instrument_names=None,
                 effect_modifier_names=None,
                 observed_node_names=None,
                 missing_nodes_as_confounders=False):
        self.treatment_name = parse_state(treatment_name)
        self.outcome_name = parse_state(outcome_name)
        instrument_names = parse_state(instrument_names)
        common_cause_names = parse_state(common_cause_names)
        effect_modifier_names = parse_state(effect_modifier_names)
        self.logger = logging.getLogger(__name__)

        if graph is None:
            self._graph = nx.DiGraph()
            self._graph = self.build_graph(common_cause_names,
                                           instrument_names, effect_modifier_names)
        elif re.match(r".*\.dot", graph):
            # load dot file
            try:
                import pygraphviz as pgv
                self._graph = nx.DiGraph(nx.drawing.nx_agraph.read_dot(graph))
            except Exception as e:
                self.logger.error("Pygraphviz cannot be loaded. " + str(e) + "\nTrying pydot...")
                try:
                    import pydot
                    self._graph = nx.DiGraph(nx.drawing.nx_pydot.read_dot(graph))
GitHub: microsoft / dowhy / dowhy / causal_model.py (View on GitHub)
self._outcome = parse_state(outcome)
        self._estimand_type = estimand_type
        self._proceed_when_unidentifiable = proceed_when_unidentifiable
        self._missing_nodes_as_confounders = missing_nodes_as_confounders
        if 'logging_level' in kwargs:
            logging.basicConfig(level=kwargs['logging_level'])
        else:
            logging.basicConfig(level=logging.INFO)

        # TODO: move the logging level argument to a json file. Tue 20 Feb 2018 06:56:27 PM DST
        self.logger = logging.getLogger(__name__)

        if graph is None:
            self.logger.warning("Causal Graph not provided. DoWhy will construct a graph based on data inputs.")
            self._common_causes = parse_state(common_causes)
            self._instruments = parse_state(instruments)
            self._effect_modifiers = parse_state(effect_modifiers)
            if common_causes is not None and instruments is not None:
                self._graph = CausalGraph(
                    self._treatment,
                    self._outcome,
                    common_cause_names=self._common_causes,
                    instrument_names=self._instruments,
                    effect_modifier_names = self._effect_modifiers,
                    observed_node_names=self._data.columns.tolist()
                )
            elif common_causes is not None:
                self._graph = CausalGraph(
                    self._treatment,
                    self._outcome,
                    common_cause_names=self._common_causes,
                    effect_modifier_names = self._effect_modifiers,
GitHub: microsoft / dowhy / dowhy / causal_graph.py (View on GitHub)
def __init__(self,
                 treatment_name, outcome_name,
                 graph=None,
                 common_cause_names=None,
                 instrument_names=None,
                 effect_modifier_names=None,
                 observed_node_names=None,
                 missing_nodes_as_confounders=False):
        self.treatment_name = parse_state(treatment_name)
        self.outcome_name = parse_state(outcome_name)
        instrument_names = parse_state(instrument_names)
        common_cause_names = parse_state(common_cause_names)
        effect_modifier_names = parse_state(effect_modifier_names)
        self.logger = logging.getLogger(__name__)

        if graph is None:
            self._graph = nx.DiGraph()
            self._graph = self.build_graph(common_cause_names,
                                           instrument_names, effect_modifier_names)
        elif re.match(r".*\.dot", graph):
            # load dot file
            try:
                import pygraphviz as pgv
                self._graph = nx.DiGraph(nx.drawing.nx_agraph.read_dot(graph))
            except Exception as e:
                self.logger.error("Pygraphviz cannot be loaded. " + str(e) + "\nTrying pydot...")
GitHub: microsoft / dowhy / dowhy / causal_graph.py (View on GitHub)
def get_causes(self, nodes, remove_edges=None):
    """Return the union of the ancestors of every node in ``nodes``.

    :param nodes: node name(s), normalized with ``parse_state``.
    :param remove_edges: optional mapping with ``"sources"`` and ``"targets"``
        entries (each normalized with ``parse_state``). When given, every
        (source, target) pair is removed as an edge from a copy of
        ``self._graph``, and that pruned copy is passed to
        ``get_ancestors`` as ``new_graph``. ``self._graph`` itself is not
        modified. networkx's ``remove_edge`` raises ``NetworkXError`` if a
        pair is not an edge of the graph.
    :returns: a ``set`` of ancestor nodes.
    """
    target_nodes = parse_state(nodes)

    pruned_graph = None
    if remove_edges is not None:
        pruned_graph = self._graph.copy()  # caution: shallow copy of the attributes
        edge_sources = parse_state(remove_edges["sources"])
        edge_targets = parse_state(remove_edges["targets"])
        # Every source is cut from every target (Cartesian product, source-major).
        edges_to_cut = [(src, dst) for src in edge_sources for dst in edge_targets]
        for src, dst in edges_to_cut:
            pruned_graph.remove_edge(src, dst)

    ancestor_sets = [self.get_ancestors(v, new_graph=pruned_graph) for v in target_nodes]
    return set().union(*ancestor_sets)
GitHub: microsoft / dowhy / dowhy / causal_graph.py (View on GitHub)
def get_common_causes(self, nodes1, nodes2):
    """
    Assume that nodes1 causes nodes2 (e.g., nodes1 are the treatments and nodes2 are the outcomes)

    Return the nodes that are ancestors of some node in ``nodes1`` AND are
    either a parent of some node in ``nodes2`` that is not itself in
    ``nodes1``, or an ancestor of such a parent.

    :param nodes1: node name(s), normalized with ``parse_state``
    :param nodes2: node name(s), normalized with ``parse_state``
    :returns: a list of nodes; its order is unspecified because it is
        built from a set.

    NOTE(review): assuming ``get_ancestors`` returns all graph ancestors,
    a parent of ``nodes2`` that lies downstream of ``nodes1`` (e.g. a
    mediator) contributes ``nodes1`` and all of ``nodes1``'s ancestors
    (including instruments) to the second set. Those ancestors therefore
    still end up in the result. Confirm this is intended.
    """
    # TODO Refactor to remove this from here and only implement this logic in causalIdentifier. Unnecessary assumption of nodes1 to be causing nodes2.
    nodes1 = parse_state(nodes1)
    nodes2 = parse_state(nodes2)

    # Grow the sets in place instead of rebuilding them with union() on
    # every iteration.
    causes_1 = set()
    for node in nodes1:
        causes_1.update(self.get_ancestors(node))

    causes_2 = set()
    for node in nodes2:
        # Cannot simply compute ancestors, since that will also include nodes1 and its parents (e.g. instruments)
        for parent in self.get_parents(node):
            if parent in nodes1:
                continue
            causes_2.add(parent)
            causes_2.update(self.get_ancestors(parent))

    return list(causes_1.intersection(causes_2))