How to use the arviz.data.convert_to_dataset function in arviz

To help you get started, we’ve selected a few ArviZ examples based on popular ways this function is used in public projects.

Secure your code as it's written: use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

GitHub: arviz-devs / arviz / arviz / plots / rankplot.py (View on GitHub)
>>> az.plot_rank(data, var_names='tau')

    Use vlines to compare results for centered vs noncentered models

    .. plot::
        :context: close-figs

        >>> import arviz as az
        >>> centered_data = az.load_arviz_data('centered_eight')
        >>> noncentered_data = az.load_arviz_data('non_centered_eight')
        >>> _, ax = plt.subplots(1, 2, figsize=(12, 3))
        >>> az.plot_rank(centered_data, var_names="mu", kind='vlines', axes=ax[0])
        >>> az.plot_rank(noncentered_data, var_names="mu", kind='vlines', axes=ax[1])

    """
    posterior_data = convert_to_dataset(data, group="posterior")
    if coords is not None:
        posterior_data = posterior_data.sel(**coords)
    var_names = _var_names(var_names, posterior_data)
    plotters = filter_plotters_list(
        list(xarray_var_iter(posterior_data, var_names=var_names, combined=True)), "plot_rank"
    )
    length_plotters = len(plotters)

    if bins is None:
        bins = _sturges_formula(posterior_data, mult=2)

    rows, cols = default_grid(length_plotters)
    if axes is None:
        figsize, ax_labelsize, titlesize, _, _, _ = _scale_fig_size(
            figsize, None, rows=rows, cols=cols
        )
GitHub: arviz-devs / arviz / arviz / plots / posteriorplot.py (View on GitHub)
Plot posterior as a histogram

    .. plot::
        :context: close-figs

        >>> az.plot_posterior(data, var_names=['mu'], kind='hist')

    Change size of credible interval

    .. plot::
        :context: close-figs

        >>> az.plot_posterior(data, var_names=['mu'], credible_interval=.75)
    """
    data = convert_to_dataset(data, group=group)
    var_names = _var_names(var_names, data)

    if coords is None:
        coords = {}

    plotters = filter_plotters_list(
        list(xarray_var_iter(get_coords(data, coords), var_names=var_names, combined=True)),
        "plot_posterior",
    )
    length_plotters = len(plotters)
    rows, cols = default_grid(length_plotters)

    (figsize, ax_labelsize, titlesize, xt_labelsize, _linewidth, _) = _scale_fig_size(
        figsize, textsize, rows, cols
    )
    kwargs.setdefault("linewidth", _linewidth)
GitHub: arviz-devs / arviz / arviz / plots / pairplot.py (View on GitHub)
divergences_kwargs.setdefault("lw", 0)

    # Get posterior draws and combine chains
    data = convert_to_inference_data(data)
    posterior_data = convert_to_dataset(data, group="posterior")
    var_names = _var_names(var_names, posterior_data)
    flat_var_names, _posterior = xarray_to_ndarray(
        get_coords(posterior_data, coords), var_names=var_names, combined=True
    )

    divergent_data = None
    diverging_mask = None
    # Get diverging draws and combine chains
    if divergences:
        if hasattr(data, "sample_stats") and hasattr(data.sample_stats, "diverging"):
            divergent_data = convert_to_dataset(data, group="sample_stats")
            _, diverging_mask = xarray_to_ndarray(
                divergent_data, var_names=("diverging",), combined=True
            )
            diverging_mask = np.squeeze(diverging_mask)
        else:
            divergences = False
            warnings.warn(
                "Divergences data not found, plotting without divergences. "
                "Make sure the sample method provides divergences data and "
                "that it is present in the `diverging` field of `sample_stats` "
                "or set divergences=False",
                SyntaxWarning,
            )

    if gridsize == "auto":
        gridsize = int(len(_posterior[0]) ** 0.35)
GitHub: arviz-devs / arviz / arviz / plots / pairplot.py (View on GitHub)
if kind == "scatter":
        plot_kwargs.setdefault("marker", ".")
        plot_kwargs.setdefault("lw", 0)

    if divergences_kwargs is None:
        divergences_kwargs = {}

    divergences_kwargs.setdefault("marker", "o")
    divergences_kwargs.setdefault("markeredgecolor", "k")
    divergences_kwargs.setdefault("color", "C1")
    divergences_kwargs.setdefault("lw", 0)

    # Get posterior draws and combine chains
    data = convert_to_inference_data(data)
    posterior_data = convert_to_dataset(data, group="posterior")
    var_names = _var_names(var_names, posterior_data)
    flat_var_names, _posterior = xarray_to_ndarray(
        get_coords(posterior_data, coords), var_names=var_names, combined=True
    )

    divergent_data = None
    diverging_mask = None
    # Get diverging draws and combine chains
    if divergences:
        if hasattr(data, "sample_stats") and hasattr(data.sample_stats, "diverging"):
            divergent_data = convert_to_dataset(data, group="sample_stats")
            _, diverging_mask = xarray_to_ndarray(
                divergent_data, var_names=("diverging",), combined=True
            )
            diverging_mask = np.squeeze(diverging_mask)
        else:
GitHub: arviz-devs / arviz / arviz / plots / densityplot.py (View on GitHub)
.. plot::
        :context: close-figs

        >>> az.plot_density([centered, non_centered], var_names=["mu"], outline=False, shade=.8)

    Specify binwidth for kernel density estimation

    .. plot::
        :context: close-figs

        >>> az.plot_density([centered, non_centered], var_names=["mu"], bw=.9)
    """
    if not isinstance(data, (list, tuple)):
        datasets = [convert_to_dataset(data, group=group)]
    else:
        datasets = [convert_to_dataset(datum, group=group) for datum in data]

    var_names = _var_names(var_names, datasets)

    if point_estimate not in ("mean", "median", None):
        raise ValueError(
            "Point estimate should be 'mean'," "median' or None, not {}".format(point_estimate)
        )

    n_data = len(datasets)

    if data_labels is None:
        if n_data > 1:
            data_labels = ["{}".format(idx) for idx in range(n_data)]
        else:
            data_labels = [""]
    elif len(data_labels) != n_data:
GitHub: arviz-devs / arviz / arviz / plots / forestplot.py (View on GitHub)
>>> axes = az.plot_forest(non_centered_data,
        >>>                            kind='ridgeplot',
        >>>                            var_names=['theta'],
        >>>                            combined=True,
        >>>                            ridgeplot_overlap=3,
        >>>                            colors='white',
        >>>                            figsize=(9, 7))
        >>> axes[0].set_title('Estimated theta for 8 schools model')
    """
    if not isinstance(data, (list, tuple)):
        data = [data]

    if coords is None:
        coords = {}
    datasets = get_coords(
        [convert_to_dataset(datum) for datum in reversed(data)],
        list(reversed(coords)) if isinstance(coords, (list, tuple)) else coords,
    )

    var_names = _var_names(var_names, datasets)

    ncols, width_ratios = 1, [3]

    if ess:
        ncols += 1
        width_ratios.append(1)

    if r_hat:
        ncols += 1
        width_ratios.append(1)

    plot_forest_kwargs = dict(
GitHub: arviz-devs / arviz / arviz / plots / traceplot.py (View on GitHub)
try:
            divergence_data = convert_to_dataset(data, group="sample_stats").diverging
        except (ValueError, AttributeError):  # No sample_stats, or no `.diverging`
            divergences = False

    if coords is None:
        coords = {}

    if divergences:
        divergence_data = get_coords(
            divergence_data, {k: v for k, v in coords.items() if k in ("chain", "draw")}
        )
    else:
        divergence_data = False

    data = get_coords(convert_to_dataset(data, group="posterior"), coords)
    var_names = _var_names(var_names, data)

    if lines is None:
        lines = ()

    num_colors = len(data.chain) + 1 if combined else len(data.chain)

    # TODO: matplotlib is always required by arviz. Can we get rid of it?
    colors = [
        prop
        for _, prop in zip(
            range(num_colors), cycle(plt.rcParams["axes.prop_cycle"].by_key()["color"])
        )
    ]

    if compact:
GitHub: arviz-devs / arviz / arviz / plots / essplot.py (View on GitHub)
...     color="royalblue", extra_kwargs=extra_kwargs
        ... )

    """
    valid_kinds = ("local", "quantile", "evolution")
    kind = kind.lower()
    if kind not in valid_kinds:
        raise ValueError("Invalid kind, kind must be one of {} not {}".format(valid_kinds, kind))

    if coords is None:
        coords = {}
    if "chain" in coords or "draw" in coords:
        raise ValueError("chain and draw are invalid coordinates for this kind of plot")
    extra_methods = False if kind == "evolution" else extra_methods

    data = get_coords(convert_to_dataset(idata, group="posterior"), coords)
    var_names = _var_names(var_names, data)
    n_draws = data.dims["draw"]
    n_samples = n_draws * data.dims["chain"]

    ess_tail_dataset = None
    mean_ess = None
    sd_ess = None
    text_x = None
    text_va = None

    if kind == "quantile":
        probs = np.linspace(1 / n_points, 1 - 1 / n_points, n_points)
        xdata = probs
        ylabel = "{} for quantiles"
        ess_dataset = xr.concat(
            [
GitHub: arviz-devs / arviz / arviz / stats / diagnostics.py (View on GitHub)
)
        )
    rhat_func = methods[method]

    if isinstance(data, np.ndarray):
        data = np.atleast_2d(data)
        if len(data.shape) < 3:
            return rhat_func(data)
        else:
            msg = (
                "Only uni-dimensional ndarray variables are supported."
                " Please transform first to dataset with `az.convert_to_dataset`."
            )
            raise TypeError(msg)

    dataset = convert_to_dataset(data, group="posterior")
    var_names = _var_names(var_names, dataset)

    dataset = dataset if var_names is None else dataset[var_names]

    ufunc_kwargs = {"ravel": False}
    func_kwargs = {}
    return _wrap_xarray_ufunc(
        rhat_func, dataset, ufunc_kwargs=ufunc_kwargs, func_kwargs=func_kwargs
    )
GitHub: arviz-devs / arviz / arviz / plots / densityplot.py (View on GitHub)
Shade plots and/or remove outlines

    .. plot::
        :context: close-figs

        >>> az.plot_density([centered, non_centered], var_names=["mu"], outline=False, shade=.8)

    Specify binwidth for kernel density estimation

    .. plot::
        :context: close-figs

        >>> az.plot_density([centered, non_centered], var_names=["mu"], bw=.9)
    """
    if not isinstance(data, (list, tuple)):
        datasets = [convert_to_dataset(data, group=group)]
    else:
        datasets = [convert_to_dataset(datum, group=group) for datum in data]

    var_names = _var_names(var_names, datasets)

    if point_estimate not in ("mean", "median", None):
        raise ValueError(
            "Point estimate should be 'mean'," "median' or None, not {}".format(point_estimate)
        )

    n_data = len(datasets)

    if data_labels is None:
        if n_data > 1:
            data_labels = ["{}".format(idx) for idx in range(n_data)]
        else: