How to use the `arviz.utils.expand_dims` function

To help you get started, we've selected a few examples from public projects that show popular ways `arviz.utils.expand_dims` is used.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

GitHub: arviz-devs/arviz — arviz/data/io_pyro.py (view on GitHub)
def posterior_predictive_to_xarray(self):
        """Convert posterior_predictive samples to xarray.

        Each sample tensor in ``self.posterior_predictive`` is detached, moved to
        the CPU and converted to a NumPy array. It is then laid out as follows:

        * leading dims already equal ``(self.nchains, self.ndraws)``: used as-is;
        * leading dim equals ``self.nchains * self.ndraws``: reshaped to
          ``(self.nchains, self.ndraws, *rest)``;
        * anything else (including 0-d and 1-d arrays that match neither
          layout): passed through ``utils.expand_dims``, and a warning is
          logged because the samples cannot be mapped onto chains and draws.

        Returns
        -------
        The object built by ``dict_to_dataset`` from the collected arrays, using
        ``self.coords`` and ``self.dims``.
        """
        data = {}
        for k, ary in self.posterior_predictive.items():
            ary = ary.detach().cpu().numpy()
            shape = ary.shape
            # Compare tuple slices rather than indexing: slicing never raises, so
            # a 1-d array whose length happens to equal nchains (or a 0-d array)
            # no longer fails with IndexError on ``shape[1]`` / ``shape[0]``.
            if shape[:2] == (self.nchains, self.ndraws):
                data[k] = ary
            elif shape[:1] == (self.nchains * self.ndraws,):
                # Flattened (chain * draw) leading axis: split it back apart.
                data[k] = ary.reshape((self.nchains, self.ndraws, *shape[1:]))
            else:
                # NOTE(review): utils.expand_dims presumably adds a leading
                # (single-chain) axis — confirm against arviz.utils.
                data[k] = utils.expand_dims(ary)
                _log.warning(
                    "posterior predictive shape not compatible with number of chains and draws. "
                    "This can mean that some draws or even whole chains are not represented."
                )
        return dict_to_dataset(data, library=self.pyro, coords=self.coords, dims=self.dims)
GitHub: arviz-devs/arviz — arviz/data/io_tfp.py (view on GitHub)
def handle_chain_location(self, ary):
        """Move the axis corresponding to the chain to first position.

        If there is only one chain which has no axis, add it.

        The other axes keep their relative order, so an array laid out as
        ``(draw, ..., chain, ...)`` becomes ``(chain, draw, ...)``. The previous
        ``swapaxes(0, chain_dim)`` only matched this for ``chain_dim`` 0 or 1.

        Parameters
        ----------
        ary : array_like
            Samples whose chain axis is ``self.chain_dim``. ``self.chain_dim``
            is ``None`` when the samples carry no chain axis.

        Returns
        -------
        The array with the chain axis first. When ``self.chain_dim`` is
        ``None``, the result of ``utils.expand_dims(ary)`` is returned instead.
        """
        if self.chain_dim is None:
            return utils.expand_dims(ary)
        # Local import keeps this block self-contained.
        import numpy as np

        # moveaxis (unlike swapaxes) does not send axis 0 to position
        # chain_dim, so the draw axis stays right after the chain axis.
        return np.moveaxis(ary, self.chain_dim, 0)
GitHub: arviz-devs/arviz — arviz/data/io_numpyro.py (view on GitHub)
def posterior_predictive_to_xarray(self):
        """Convert posterior_predictive samples to xarray.

        Each array in ``self.posterior_predictive`` is laid out as follows:

        * leading dims already equal ``(self.nchains, self.ndraws)``: used as-is;
        * leading dim equals ``self.nchains * self.ndraws``: reshaped to
          ``(self.nchains, self.ndraws, *rest)``;
        * anything else (including 0-d and 1-d arrays that match neither
          layout): passed through ``utils.expand_dims``, and a warning is
          logged because the samples cannot be mapped onto chains and draws.

        Returns
        -------
        The object built by ``dict_to_dataset`` from the collected arrays, using
        ``self.coords`` and ``self.dims``.
        """
        data = {}
        for k, ary in self.posterior_predictive.items():
            shape = ary.shape
            # Compare tuple slices rather than indexing: slicing never raises, so
            # a 1-d array whose length happens to equal nchains (or a 0-d array)
            # no longer fails with IndexError on ``shape[1]`` / ``shape[0]``.
            if shape[:2] == (self.nchains, self.ndraws):
                data[k] = ary
            elif shape[:1] == (self.nchains * self.ndraws,):
                # Flattened (chain * draw) leading axis: split it back apart.
                data[k] = ary.reshape((self.nchains, self.ndraws, *shape[1:]))
            else:
                # NOTE(review): utils.expand_dims presumably adds a leading
                # (single-chain) axis — confirm against arviz.utils.
                data[k] = utils.expand_dims(ary)
                _log.warning(
                    "posterior predictive shape not compatible with number of chains and draws. "
                    "This can mean that some draws or even whole chains are not represented."
                )
        return dict_to_dataset(data, library=self.numpyro, coords=self.coords, dims=self.dims)
GitHub: arviz-devs/arviz — arviz/data/io_pyro.py (view on GitHub; excerpt begins partway through a method)
# Excerpt: the enclosing method's def line and the condition guarding this
# early return are not shown (presumably "no prior samples" — confirm).
return {"prior": None, "prior_predictive": None}
        # Decide which prior sample keys go to the "prior" group and which go to
        # the "prior_predictive" group.
        if self.posterior is not None:
            # Variables sampled in the posterior form the "prior" group; every
            # other key in self.prior is treated as a prior predictive variable.
            prior_vars = list(self.posterior.get_samples().keys())
            prior_predictive_vars = [key for key in self.prior.keys() if key not in prior_vars]
        else:
            # Without a posterior, every prior key goes to "prior" and the
            # "prior_predictive" group is left as None.
            prior_vars = self.prior.keys()
            prior_predictive_vars = None
        priors_dict = {}
        for group, var_names in zip(
            ("prior", "prior_predictive"), (prior_vars, prior_predictive_vars)
        ):
            # A group with no variable list is stored as None. Otherwise each
            # tensor is detached, moved to CPU and converted to NumPy, then
            # passed through utils.expand_dims (presumably adding a leading
            # chain axis — confirm) before being packed with dict_to_dataset.
            # NOTE(review): when a posterior exists, self.prior[k] raises
            # KeyError for any posterior variable missing from self.prior.
            priors_dict[group] = (
                None
                if var_names is None
                else dict_to_dataset(
                    {k: utils.expand_dims(self.prior[k].detach().cpu().numpy()) for k in var_names},
                    library=self.pyro,
                    coords=self.coords,
                    dims=self.dims,
                )
            )
        return priors_dict
GitHub: arviz-devs/arviz — arviz/data/io_numpyro.py (view on GitHub; excerpt begins partway through a method)
# Excerpt: the enclosing method's def line and the condition guarding this
# early return are not shown (presumably "no prior samples" — confirm).
return {"prior": None, "prior_predictive": None}
        # Decide which prior sample keys go to the "prior" group and which go to
        # the "prior_predictive" group.
        if self.posterior is not None:
            # Keys of self._samples (presumably the posterior samples — confirm)
            # form the "prior" group; every other key in self.prior is treated
            # as a prior predictive variable.
            prior_vars = list(self._samples.keys())
            prior_predictive_vars = [key for key in self.prior.keys() if key not in prior_vars]
        else:
            # Without a posterior, every prior key goes to "prior" and the
            # "prior_predictive" group is left as None.
            prior_vars = self.prior.keys()
            prior_predictive_vars = None
        priors_dict = {}
        for group, var_names in zip(
            ("prior", "prior_predictive"), (prior_vars, prior_predictive_vars)
        ):
            # A group with no variable list is stored as None. Otherwise each
            # array is passed through utils.expand_dims (presumably adding a
            # leading chain axis — confirm) and packed with dict_to_dataset.
            # NOTE(review): when a posterior exists, self.prior[k] raises
            # KeyError for any sampled variable missing from self.prior.
            priors_dict[group] = (
                None
                if var_names is None
                else dict_to_dataset(
                    {k: utils.expand_dims(self.prior[k]) for k in var_names},
                    library=self.numpyro,
                    coords=self.coords,
                    dims=self.dims,
                )
            )
        return priors_dict