How to use the arviz.plots.plot_utils.make_label function in arviz

To help you get started, we’ve selected a few arviz examples that show popular ways `make_label` is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

GitHub: arviz-devs/arviz — arviz/plots/backends/bokeh/forestplot.py (view on GitHub)
skip_dims = {"chain"}
        else:
            grouped_data = [datum.groupby("chain") for datum in self.data]
            skip_dims = set()

        label_dict = OrderedDict()
        for name, grouped_datum in zip(self.model_names, grouped_data):
            for _, sub_data in grouped_datum:
                datum_iter = xarray_var_iter(
                    sub_data,
                    var_names=[self.var_name],
                    skip_dims=skip_dims,
                    reverse_selections=True,
                )
                for _, selection, values in datum_iter:
                    label = make_label(self.var_name, selection, position="beside")
                    if label not in label_dict:
                        label_dict[label] = OrderedDict()
                    if name not in label_dict[label]:
                        label_dict[label][name] = []
                    label_dict[label][name].append(values)

        y = self.y_start
        for label, model_data in label_dict.items():
            for model_name, value_list in model_data.items():
                if model_name:
                    row_label = "{}: {}".format(model_name, label)
                else:
                    row_label = label
                for values in value_list:
                    yield y, row_label, label, values, self.model_color[model_name]
                    y += self.chain_offset
GitHub: arviz-devs/arviz — arviz/plots/backends/bokeh/rankplot.py (view on GitHub)
ax.xaxis.axis_label = "Rank (all chains)"

            ax.yaxis.ticker = FixedTicker(ticks=y_ticks)
            ax.xaxis.major_label_overrides = dict(
                zip(map(str, y_ticks), map(str, range(len(y_ticks))))
            )

        else:
            ax.yaxis.major_tick_line_color = None
            ax.yaxis.minor_tick_line_color = None

            ax.xaxis.major_label_text_font_size = "0pt"
            ax.yaxis.major_label_text_font_size = "0pt"

        _title = Title()
        _title.text = make_label(var_name, selection)
        ax.title = _title

    if backend_show(show):
        grid = gridplot(axes.tolist(), toolbar_location="above")
        bkp.show(grid)

    return axes
GitHub: arviz-devs/arviz — arviz/plots/backends/bokeh/mcseplot.py (view on GitHub)
hline = Span(
                    location=0,
                    dimension="width",
                    line_color="black",
                    line_width=_linewidth,
                    line_alpha=0.7,
                )

            ax_.renderers.append(hline)

            glyph = Dash(x="rug_x", y="rug_y", **_rug_kwargs)
            cds_rug = ColumnDataSource({"rug_x": np.asarray(rug_x), "rug_y": np.asarray(rug_y)})
            ax_.add_glyph(cds_rug, glyph)

        title = Title()
        title.text = make_label(var_name, selection)
        ax_.title = title

        ax_.xaxis.axis_label = "Quantile"
        ax_.yaxis.axis_label = (
            r"Value $\pm$ MCSE for quantiles" if errorbar else "MCSE for quantiles"
        )

        if not errorbar:
            ax_.y_range._property_values["start"] = -0.05  # pylint: disable=protected-access
            ax_.y_range._property_values["end"] = 1  # pylint: disable=protected-access

    if backend_show(show):
        grid = gridplot(ax.tolist(), toolbar_location="above")
        bkp.show(grid)

    return ax
GitHub: arviz-devs/arviz — arviz/plots/backends/matplotlib/mcseplot.py (view on GitHub)
rug_kwargs.setdefault("marker", "|")
            rug_kwargs.setdefault("linestyle", rug_kwargs.pop("ls", "None"))
            rug_kwargs.setdefault("color", rug_kwargs.pop("c", kwargs.get("color", "C0")))
            rug_kwargs.setdefault("space", 0.1)
            rug_kwargs.setdefault("markersize", rug_kwargs.pop("ms", 2 * _markersize))

            mask = idata.sample_stats[rug_kind].values.flatten()
            values = rankdata(values)[mask]
            y_min, y_max = ax_.get_ylim()
            y_min = y_min if errorbar else 0
            rug_space = (y_max - y_min) * rug_kwargs.pop("space")
            rug_x, rug_y = values / (len(mask) - 1), np.full_like(values, y_min) - rug_space
            ax_.plot(rug_x, rug_y, **rug_kwargs)
            ax_.axhline(y_min, color="k", linewidth=_linewidth, alpha=0.7)

        ax_.set_title(make_label(var_name, selection), fontsize=titlesize, wrap=True)
        ax_.tick_params(labelsize=xt_labelsize)
        ax_.set_xlabel("Quantile", fontsize=ax_labelsize, wrap=True)
        ax_.set_ylabel(
            r"Value $\pm$ MCSE for quantiles" if errorbar else "MCSE for quantiles",
            fontsize=ax_labelsize,
            wrap=True,
        )
        ax_.set_xlim(0, 1)
        if rug:
            ax_.yaxis.get_major_locator().set_params(nbins="auto", steps=[1, 2, 5, 10])
            y_min, y_max = ax_.get_ylim()
            yticks = ax_.get_yticks()
            yticks = yticks[(yticks >= y_min) & (yticks < y_max)]
            ax_.set_yticks(yticks)
            ax_.set_yticklabels(["{:.3g}".format(ytick) for ytick in yticks])
        elif not errorbar:
GitHub: arviz-devs/arviz — arviz/plots/backends/bokeh/densityplot.py (view on GitHub)
length_plotters,
        rows,
        cols,
        figsize=figsize,
        squeeze=False,
        backend="bokeh",
        backend_kwargs=backend_kwargs,
    )

    axis_map = {label: ax_ for label, ax_ in zip(all_labels, ax.flatten())}
    if data_labels is None:
        data_labels = {}

    for m_idx, plotters in enumerate(to_plot):
        for ax_idx, (var_name, selection, values) in enumerate(plotters):
            label = make_label(var_name, selection)

            if data_labels:
                data_label = data_labels[m_idx]
                if ax_idx != 0 or data_label == "":
                    data_label = None
            else:
                data_label = None

            _d_helper(
                values.flatten(),
                label,
                colors[m_idx],
                bw,
                line_width,
                markersize,
                credible_interval,
GitHub: arviz-devs/arviz — arviz/plots/backends/bokeh/traceplot.py (view on GitHub)
x_name=draw_name,
                    y_name=y_name,
                    colors=colors,
                    combined=combined,
                    rug=rug,
                    legend=legend,
                    trace_kwargs=trace_kwargs,
                    hist_kwargs=hist_kwargs,
                    plot_kwargs=plot_kwargs,
                    fill_kwargs=fill_kwargs,
                    rug_kwargs=rug_kwargs,
                )

        for col in (0, 1):
            _title = Title()
            _title.text = make_label(var_name, selection)
            axes[idx, col].title = _title

        for _, _, vlines in (j for j in lines if j[0] == var_name and j[1] == selection):
            if isinstance(vlines, (float, int)):
                line_values = [vlines]
            else:
                line_values = np.atleast_1d(vlines).ravel()

            for line_value in line_values:
                vline = Span(
                    location=line_value,
                    dimension="height",
                    line_color="black",
                    line_width=1.5,
                    line_alpha=0.75,
                )
GitHub: arviz-devs/arviz — arviz/plots/backends/matplotlib/forestplot.py (view on GitHub)
skip_dims = {"chain"}
        else:
            grouped_data = [datum.groupby("chain") for datum in self.data]
            skip_dims = set()

        label_dict = OrderedDict()
        for name, grouped_datum in zip(self.model_names, grouped_data):
            for _, sub_data in grouped_datum:
                datum_iter = xarray_var_iter(
                    sub_data,
                    var_names=[self.var_name],
                    skip_dims=skip_dims,
                    reverse_selections=True,
                )
                for _, selection, values in datum_iter:
                    label = make_label(self.var_name, selection, position="beside")
                    if label not in label_dict:
                        label_dict[label] = OrderedDict()
                    if name not in label_dict[label]:
                        label_dict[label][name] = []
                    label_dict[label][name].append(values)

        y = self.y_start
        for label, model_data in label_dict.items():
            for model_name, value_list in model_data.items():
                if model_name:
                    row_label = "{}: {}".format(model_name, label)
                else:
                    row_label = label
                for values in value_list:
                    yield y, row_label, label, values, self.model_color[model_name]
                    y += self.chain_offset
GitHub: arviz-devs/arviz — arviz/plots/backends/matplotlib/densityplot.py (view on GitHub)
):
    """Matplotlib densityplot."""
    _, ax = _create_axes_grid(
        length_plotters,
        rows,
        cols,
        figsize=figsize,
        squeeze=False,
        backend="matplotlib",
        backend_kwargs=backend_kwargs,
    )
    axis_map = {label: ax_ for label, ax_ in zip(all_labels, ax.flatten())}

    for m_idx, plotters in enumerate(to_plot):
        for var_name, selection, values in plotters:
            label = make_label(var_name, selection)
            _d_helper(
                values.flatten(),
                label,
                colors[m_idx],
                bw,
                titlesize,
                xt_labelsize,
                linewidth,
                markersize,
                credible_interval,
                point_estimate,
                hpd_markers,
                outline,
                shade,
                axis_map[label],
            )