How to use the arviz.data.base.requires function in arviz

To help you get started, we’ve selected a few arviz examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github arviz-devs / arviz / arviz / data / io_cmdstan.py View on Github external
    @requires("prior")
    def prior_to_xarray(self):
        """Convert prior samples to xarray."""
        # filter prior_predictive
        prior_predictive = self.prior_predictive
        columns = self.prior[0].columns
        if prior_predictive is None or (
            isinstance(prior_predictive, str) and prior_predictive.lower().endswith(".csv")
        ):
            prior_predictive = []
        elif isinstance(prior_predictive, str):
            prior_predictive = [col for col in columns if prior_predictive == col.split(".")[0]]
        else:
            prior_predictive = [
                col
                for col in columns
                if any(item == col.split(".")[0] for item in prior_predictive)
github arviz-devs / arviz / arviz / data / io_cmdstanpy.py View on Github external
    @requires("posterior_predictive")
    def posterior_predictive_to_xarray(self):
        """Convert posterior_predictive samples to xarray.

        ``self.posterior_predictive`` may be a single variable name or an
        iterable of names. A column of the fit is selected when the part of
        its name before the first ``.`` equals one of those names.

        Returns
        -------
        The value produced by ``dict_to_dataset`` for the selected columns.
        """
        posterior_predictive = self.posterior_predictive
        columns = self.posterior.column_names

        if isinstance(posterior_predictive, str):
            posterior_predictive = [posterior_predictive]
        # Build the lookup set once; the original rebuilt it for every column.
        wanted = set(posterior_predictive)
        valid_cols = [col for col in columns if col.split(".")[0] in wanted]
        data = _unpack_frame(self.posterior.sample, columns, valid_cols)
        return dict_to_dataset(data, library=self.cmdstanpy, coords=self.coords, dims=self.dims)
github arviz-devs / arviz / arviz / data / io_cmdstan.py View on Github external
    @requires("sample_stats")
    def sample_stats_to_xarray(self):
        """Extract sample_stats from fit."""
        dtypes = {"divergent__": bool, "n_leapfrog__": np.int64, "treedepth__": np.int64}

        # copy dims and coords
        dims = deepcopy(self.dims) if self.dims is not None else {}
        coords = deepcopy(self.coords) if self.coords is not None else {}

        sampler_params = self.sample_stats
        log_likelihood = self.log_likelihood
        if isinstance(log_likelihood, str):
            log_likelihood_cols = [
                col for col in self.posterior[0].columns if log_likelihood == col.split(".")[0]
            ]
            log_likelihood_vals = [item[log_likelihood_cols] for item in self.posterior]
github arviz-devs / arviz / arviz / data / io_cmdstan.py View on Github external
    @requires("posterior")
    def posterior_to_xarray(self):
        """Extract posterior samples from output csv."""
        columns = self.posterior[0].columns

        # filter posterior_predictive and log_likelihood
        posterior_predictive = self.posterior_predictive
        if posterior_predictive is None or (
            isinstance(posterior_predictive, str) and posterior_predictive.lower().endswith(".csv")
        ):
            posterior_predictive = []
        elif isinstance(posterior_predictive, str):
            posterior_predictive = [
                col for col in columns if posterior_predictive == col.split(".")[0]
            ]
        else:
            posterior_predictive = [
github arviz-devs / arviz / arviz / data / io_pystan.py View on Github external
    @requires("prior_predictive")
    def prior_predictive_to_xarray(self):
        """Build the prior_predictive dataset from the prior fit.

        Draws are pulled with ``get_draws_stan3`` using the prior fit, the
        prior model and the requested prior_predictive variables, and the
        result is passed to ``dict_to_dataset``.
        """
        draws = get_draws_stan3(
            self.prior,
            model=self.prior_model,
            variables=self.prior_predictive,
        )
        return dict_to_dataset(
            draws,
            library=self.stan,
            coords=self.coords,
            dims=self.dims,
        )
github arviz-devs / arviz / arviz / data / io_cmdstan.py View on Github external
    @requires("posterior_predictive")
    def posterior_predictive_to_xarray(self):
        """Convert posterior_predictive samples to xarray."""
        posterior_predictive = self.posterior_predictive
        columns = self.posterior[0].columns
        if (
            isinstance(posterior_predictive, (tuple, list))
            and posterior_predictive[0].endswith(".csv")
        ) or (isinstance(posterior_predictive, str) and posterior_predictive.endswith(".csv")):
            if isinstance(posterior_predictive, str):
                posterior_predictive = [posterior_predictive]
            chain_data = []
            for path in posterior_predictive:
                parsed_output = _read_output(path)
                for sample, *_ in parsed_output:
                    chain_data.append(sample)
            data = _unpack_dataframes(chain_data)
github arviz-devs / arviz / arviz / data / io_pymc3.py View on Github external
    @requires("trace")
    @requires("model")
    def constant_data_to_xarray(self):
        """Convert constant data to xarray.

        Constant data candidates are the model's named variables that are
        neither default trace variables nor observations.
        """
        # Names the trace reports by default, transformed variables included.
        model_vars = self.pymc3.util.get_default_varnames(  # pylint: disable=no-member
            self.trace.varnames, include_transformed=True
        )
        if self.observations is not None:
            # Exclude observations both by their own ``name`` attribute (when
            # they have one) and by the key they are stored under.
            model_vars.extend(
                [obs.name for obs in self.observations.values() if hasattr(obs, "name")]
            )
            model_vars.extend(self.observations.keys())
        # Whatever named variable is left over is treated as constant data.
        constant_data_vars = {
            name: var for name, var in self.model.named_vars.items() if name not in model_vars
        }
        if not constant_data_vars:
            # Nothing left over, so there is no constant data to convert.
            return None
github arviz-devs / arviz / arviz / data / io_dict.py View on Github external
    @requires("posterior")
    def posterior_to_xarray(self):
        """Build the posterior dataset from the stored dictionary.

        Raises
        ------
        TypeError
            If ``self.posterior`` is not a ``dict``.

        Warns
        -----
        SyntaxWarning
            If the dictionary contains a ``"log_likelihood"`` entry.
        """
        posterior = self.posterior
        if isinstance(posterior, dict):
            if "log_likelihood" in posterior:
                message = (
                    "log_likelihood found in posterior."
                    " For stats functions log_likelihood needs to be in sample_stats."
                )
                warnings.warn(message, SyntaxWarning)
            return dict_to_dataset(posterior, library=None, coords=self.coords, dims=self.dims)

        raise TypeError("DictConverter.posterior is not a dictionary")
github arviz-devs / arviz / arviz / data / io_pymc3.py View on Github external
    @requires("trace")
    def sample_stats_to_xarray(self):
        """Collect the sampler statistics of the PyMC3 trace into a dataset.

        Every stat listed in ``trace.stat_names`` is fetched per chain
        (``combine=False``) and stored as an array; ``model_logp`` is stored
        under the name ``lp``. When ``_extract_log_likelihood`` yields a
        log likelihood it is added as ``log_likelihood`` together with its
        dims; otherwise no dims are passed on.
        """
        renames = {"model_logp": "lp"}
        trace = self.trace
        data = {
            renames.get(stat, stat): np.array(trace.get_sampler_stats(stat, combine=False))
            for stat in trace.stat_names
        }

        log_likelihood, log_likelihood_dims = self._extract_log_likelihood()
        dims = None
        if log_likelihood is not None:
            data["log_likelihood"] = log_likelihood
            dims = {"log_likelihood": log_likelihood_dims}

        return dict_to_dataset(data, library=self.pymc3, dims=dims, coords=self.coords)
github arviz-devs / arviz / arviz / data / io_pystan.py View on Github external
    @requires("prior")
    def sample_stats_prior_to_xarray(self):
        """Build the sample_stats_prior dataset from the prior fit.

        Sampler statistics are read with ``get_sample_stats_stan3`` from the
        prior fit and prior model, then handed to ``dict_to_dataset``.
        """
        stats = get_sample_stats_stan3(self.prior, model=self.prior_model)
        return dict_to_dataset(
            stats,
            library=self.stan,
            coords=self.coords,
            dims=self.dims,
        )