How to use the `arviz.utils.one_de` function in ArviZ

To help you get started, we’ve selected a few examples of how `arviz.utils.one_de` is commonly used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

Source: arviz-devs/arviz — `arviz/data/io_numpyro.py` (view on GitHub)
def observed_data_to_xarray(self):
    """Build an ``xr.Dataset`` holding the observed data.

    Every entry of ``self.observations`` becomes one data variable. Dimension
    names are looked up in ``self.dims`` (an empty mapping is used when it is
    ``None``) and resolved, together with ``self.coords``, through
    ``generate_dims_coords``. Only the coordinates that belong to the
    variable's own dimensions are attached to its ``xr.DataArray``. Dataset
    attributes come from ``make_attrs(library=self.numpyro)``.
    """
    dims_by_var = self.dims if self.dims is not None else {}
    data_vars = {}
    for var_name, raw in self.observations.items():
        array = utils.one_de(raw)
        requested_dims = dims_by_var.get(var_name)
        var_dims, all_coords = generate_dims_coords(
            array.shape, var_name, dims=requested_dims, coords=self.coords
        )
        # Restrict the coordinates to this variable's dimensions.
        var_coords = {
            dim: xr.IndexVariable((dim,), data=all_coords[dim]) for dim in var_dims
        }
        data_vars[var_name] = xr.DataArray(array, dims=var_dims, coords=var_coords)
    return xr.Dataset(data_vars=data_vars, attrs=make_attrs(library=self.numpyro))
Source: arviz-devs/arviz — `arviz/data/io_emcee.py` (view on GitHub)
ndim = sampler.chain.shape[-1]
        num_args = len(sampler.args)
    elif hasattr(sampler, "log_prob_fn"):
        ndim = sampler.get_chain().shape[-1]
        num_args = len(sampler.log_prob_fn.args)
    else:
        ndim = sampler.get_chain().shape[-1]
        num_args = 0  # emcee only stores the posterior samples

    if slices is None:
        slices = utils.arange(ndim)
        num_vars = ndim
    else:
        num_vars = len(slices)
    indexs = utils.arange(ndim)
    slicing_try = np.concatenate([utils.one_de(indexs[idx]) for idx in slices])
    if len(set(slicing_try)) != ndim:
        warnings.warn(
            "Check slices: Not all parameters in chain captured. "
            "{} are present, and {} have been captured.".format(ndim, len(slicing_try)),
            SyntaxWarning,
        )
    if len(slicing_try) != len(set(slicing_try)):
        warnings.warn(
            "Overlapping slices. Check the index present: {}".format(slicing_try), SyntaxWarning
        )

    if var_names is None:
        var_names = ["var_{}".format(idx) for idx in range(num_vars)]
    if arg_names is None:
        arg_names = ["arg_{}".format(idx) for idx in range(num_args)]
Source: arviz-devs/arviz — `arviz/data/io_pymc3.py` (view on GitHub)
def log_likelihood_vals_point(point):
    """Return the pointwise log likelihood of every observed variable at ``point``.

    Each ``(var, log_like)`` pair in the enclosing ``cached`` sequence is
    evaluated at ``point``. For variables whose ``missing_values`` is truthy,
    only the entries where ``var.observations.mask`` is False are kept. The
    per-variable results are concatenated, in ``cached`` order, into a single
    array with ``np.concatenate``.
    """

    def _observed_entries(var, log_like):
        # Normalise the raw log-likelihood with utils.one_de before any masking.
        values = utils.one_de(log_like(point))
        if not var.missing_values:
            return values
        # Keep only the positions that are not masked in the observations.
        return values[~var.observations.mask]

    return np.concatenate([_observed_entries(var, log_like) for var, log_like in cached])
Source: arviz-devs/arviz — `arviz/data/io_cmdstan.py` (view on GitHub)
def observed_data_to_xarray(self):
    """Convert observed data to xarray.

    Reads the raw observed data with ``_read_data(self.observed_data)``. It
    then keeps only the variables named in ``self.observed_data_var``, which
    may be a single name, a collection of names, or ``None`` to keep every
    variable. Each kept variable is wrapped in an ``xr.DataArray`` whose dims
    and coords come from ``generate_dims_coords``.

    Returns
    -------
    xr.Dataset
        One data variable per selected observed variable.
    """
    observed_data_raw = _read_data(self.observed_data)
    variables = self.observed_data_var
    if isinstance(variables, str):
        variables = [variables]
    # Tolerate ``dims=None`` the same way the other converters do,
    # instead of failing on ``None.get``.
    dims = {} if self.dims is None else self.dims
    observed_data = {}
    for key, vals in observed_data_raw.items():
        if variables is not None and key not in variables:
            continue
        # NOTE(review): one_de presumably guarantees an array with a shape,
        # since ``vals.shape`` is used below; confirm against arviz.utils.
        vals = utils.one_de(vals)
        val_dims = dims.get(key)
        val_dims, coords = generate_dims_coords(
            vals.shape, key, dims=val_dims, coords=self.coords
        )
        # Keep only the coords for this variable's own dims. xarray rejects a
        # coordinate whose dimension is not one of the DataArray's dims, so any
        # unrelated entry in ``self.coords`` must not be passed through.
        coords = {dim: xr.IndexVariable((dim,), data=coords[dim]) for dim in val_dims}
        observed_data[key] = xr.DataArray(vals, dims=val_dims, coords=coords)
    return xr.Dataset(data_vars=observed_data)
Source: arviz-devs/arviz — `arviz/data/io_pymc3.py` (view on GitHub)
def observed_data_to_xarray(self):
    """Build an ``xr.Dataset`` holding the observed data.

    Every entry of ``self.observations`` becomes one data variable. Values that
    expose a ``get_value`` method are unwrapped with it first. This presumably
    covers Theano shared variables; confirm with the caller. Dimension names
    are looked up in ``self.dims``, using an empty mapping when it is ``None``,
    and resolved together with ``self.coords`` through
    ``generate_dims_coords``. Only the coordinates belonging to the variable's
    own dimensions are attached. Dataset attributes come from
    ``make_attrs(library=self.pymc3)``.
    """
    dims_by_var = self.dims if self.dims is not None else {}
    data_vars = {}
    for var_name, raw in self.observations.items():
        raw = raw.get_value() if hasattr(raw, "get_value") else raw
        array = utils.one_de(raw)
        requested_dims = dims_by_var.get(var_name)
        var_dims, all_coords = generate_dims_coords(
            array.shape, var_name, dims=requested_dims, coords=self.coords
        )
        # Restrict the coordinates to this variable's dimensions.
        var_coords = {
            dim: xr.IndexVariable((dim,), data=all_coords[dim]) for dim in var_dims
        }
        data_vars[var_name] = xr.DataArray(array, dims=var_dims, coords=var_coords)
    return xr.Dataset(data_vars=data_vars, attrs=make_attrs(library=self.pymc3))