How to use the shap.plots.colors module in shap

To help you get started, we’ve selected a few shap examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github slundberg / shap / shap / benchmark / plots.py View on Github external
def get_method_color(method):
    """Return the plot color declared in the docstring of ``methods.<method>``.

    The docstring of the named benchmark method is scanned line by line for
    the first entry of the form ``color = <value>``. If ``<value>`` is written
    as ``red_blue_circle(<float>)`` it is resolved through
    ``colors.red_blue_circle``; otherwise the value text is returned as-is.

    Parameters
    ----------
    method : str
        Name of an attribute of the ``methods`` namespace.

    Returns
    -------
    The resolved color, or ``"#000000"`` when the method has no docstring or
    its docstring declares no color.
    """
    doc = getattr(methods, method).__doc__
    if not doc:
        # Objects without a docstring have __doc__ == None; fall back to black
        # instead of crashing on None.split().
        return "#000000"

    circle_prefix = "red_blue_circle("
    for raw_line in doc.split("\n"):
        line = raw_line.strip()
        if not line.startswith("color = "):
            continue
        # partition() keeps everything after the FIRST "=", so a value that
        # itself contains "=" is not truncated (split("=")[1] would be).
        value = line.partition("=")[2].strip()
        if value.startswith(circle_prefix):
            # Strip the "red_blue_circle(" prefix and the trailing ")".
            return colors.red_blue_circle(float(value[len(circle_prefix):-1]))
        return value
    return "#000000"
github slundberg / shap / shap / plots / waterfall.py View on Github external
fig = pl.gcf()
    ax = pl.gca()
    xticks = ax.get_xticks()
    bbox = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
    width, height = bbox.width, bbox.height
    bbox_to_xscale = xlen/width
    hl_scaled = bbox_to_xscale * head_length
    renderer = fig.canvas.get_renderer()
    
    # draw the positive arrows
    for i in range(len(pos_inds)):
        dist = pos_widths[i]
        arrow_obj = pl.arrow(
            pos_lefts[i], pos_inds[i], max(dist-hl_scaled, 0.000001), 0,
            head_length=min(dist, hl_scaled),
            color=colors.red_rgb, width=bar_width,
            head_width=bar_width
        )
        
        txt_obj = pl.text(
            pos_lefts[i] + 0.5*dist, pos_inds[i], format_value(pos_widths[i], '%+0.02f'),
            horizontalalignment='center', verticalalignment='center', color="white",
            fontsize=12
        )
        text_bbox = txt_obj.get_window_extent(renderer=renderer)
        arrow_bbox = arrow_obj.get_window_extent(renderer=renderer)
        if text_bbox.width > arrow_bbox.width: 
            txt_obj.remove()
    
    # draw the negative arrows
    for i in range(len(neg_inds)):
        dist = neg_widths[i]
github slundberg / shap / shap / plots / bar.py View on Github external
if max_display is None:
        max_display = 7
    else:
        max_display = min(len(feature_names), max_display)
        
    
    feature_order = np.argsort(-np.abs(shap_values))
    
    # 
    feature_inds = feature_order[:max_display]
    y_pos = np.arange(len(feature_inds), 0, -1)
    pl.barh(
        y_pos, shap_values[feature_inds],
        0.7, align='center',
        color=[colors.red_rgb if shap_values[feature_inds[i]] > 0 else colors.blue_rgb for i in range(len(y_pos))]
    )
    pl.yticks(y_pos, fontsize=13)
    if features is not None:
        features = list(features)

        # try and round off any trailing zeros after the decimal point in the feature values
        for i in range(len(features)):
            try:
                if round(features[i]) == features[i]:
                    features[i] = int(features[i])
            except TypeError:
                pass # features[i] must not be a number
    yticklabels = []
    for i in feature_inds:
        if features is not None:
            yticklabels.append(feature_names[i] + " = " + str(features[i]))
github KienVu2368 / tabint / tabint / interpretation.py View on Github external
plot_types = ['contour', 'grid'] if plot_types is None else [plot_types]
        for plot_type in plot_types:
            figs, ax = pdp.pdp_interact_plot(
                pdp_interact_out = ft_plot,
                feature_names = var_name or feature,
                plot_type= plot_type, plot_pdp=True,
                which_classes=which_classes, plot_params = plot_params)
        plt.show()

    def sample(self, sample): return self.df if sample is None else self.df.sample(sample)

# Cell
# Hardcode a replacement for shap's default colormaps.
# Build a 256-entry linear colormap running from yellow ('#ffff00') at 0.0 to
# dark blue ('#002266') at 1.0.
# NOTE(review): despite the variable name `green_blue`, the low end is yellow,
# not green.
green_blue = LinearSegmentedColormap.from_list('custom blue', [(0, '#ffff00'), (1, '#002266')], N=256)
# Overwrite the `red_blue` and `red_blue_solid` attributes on `cl` when this
# module is imported. `cl` is presumably shap's colors module imported under an
# alias (TODO confirm). If so, shap plotting code that looks these attributes
# up at call time will render with green_blue instead of shap's defaults.
cl.red_blue = green_blue
cl.red_blue_solid = green_blue

# Cell
class Shapley:
    """
    SHAP value: https://github.com/slundberg/shap
    """
    def __init__(self, explainer, shap_values, df, df_disp, features):
        shap.initjs()
        self.explainer = explainer
        self.shap_values = shap_values
        self.df, self.df_disp, self.features = df, df_disp, features

    @classmethod
    def from_Tree(cls, learner, ds, df_disp = None, sample = 10000, remove_outlier = True):

        if remove_outlier:
github slundberg / shap / shap / plots / monitoring.py View on Github external
ys = shap_values[:,ind]
    xs = np.arange(len(ys))#np.linspace(0, 12*2, len(ys))
    
    pvals = []
    inc = 50
    for i in range(inc, len(ys)-inc, inc):
        #stat, pval = scipy.stats.mannwhitneyu(v[:i], v[i:], alternative="two-sided")
        stat, pval = scipy.stats.ttest_ind(ys[:i], ys[i:])
        pvals.append(pval)
    min_pval = np.min(pvals)
    min_pval_ind = np.argmin(pvals)*inc + inc
    
    if min_pval < 0.05 / shap_values.shape[1]:
        pl.axvline(min_pval_ind, linestyle="dashed", color="#666666", alpha=0.2)
        
    pl.scatter(xs, ys, s=10, c=features[:,ind], cmap=colors.red_blue)
    
    pl.xlabel("Sample index")
    pl.ylabel(truncate_text(feature_names[ind], 30) + "\nSHAP value", size=13)
    pl.gca().xaxis.set_ticks_position('bottom')
    pl.gca().yaxis.set_ticks_position('left')
    pl.gca().spines['right'].set_visible(False)
    pl.gca().spines['top'].set_visible(False)
    cb = pl.colorbar()
    cb.outline.set_visible(False)
    bbox = cb.ax.get_window_extent().transformed(pl.gcf().dpi_scale_trans.inverted())
    cb.ax.set_aspect((bbox.height - 0.7) * 20)
    cb.set_label(truncate_text(feature_names[ind], 30), size=13)
    if show:
        pl.show()
github slundberg / shap / shap / plots / embedding.py View on Github external
else:
        cvals = shap_values[:,ind]
        fname = feature_names[ind]
    
    # see if we need to compute the embedding
    if type(method) == str and method == "pca":
        pca = sklearn.decomposition.PCA(2)
        embedding_values = pca.fit_transform(shap_values)
    elif hasattr(method, "shape") and method.shape[1] == 2:
        embedding_values = method
    else:
        print("Unsupported embedding method:", method)

    pl.scatter(
        embedding_values[:,0], embedding_values[:,1], c=cvals,
        cmap=colors.red_blue, alpha=alpha, linewidth=0
    )
    pl.axis("off")
    #pl.title(feature_names[ind])
    cb = pl.colorbar()
    cb.set_label("SHAP value for\n"+fname, size=13)
    cb.outline.set_visible(False)
    
    
    pl.gcf().set_size_inches(7.5, 5)
    bbox = cb.ax.get_window_extent().transformed(pl.gcf().dpi_scale_trans.inverted())
    cb.ax.set_aspect((bbox.height - 0.7) * 10)
    cb.set_alpha(1)
    if show:
        pl.show()
github KienVu2368 / tabint / tabint / interpretation.py View on Github external
plot_types = ['contour', 'grid'] if plot_types is None else [plot_types]
        for plot_type in plot_types:
            figs, ax = pdp.pdp_interact_plot(
                pdp_interact_out = ft_plot,
                feature_names = var_name or feature,
                plot_type= plot_type, plot_pdp=True,
                which_classes=which_classes, plot_params = plot_params)
        plt.show()

    def sample(self, sample): return self.df if sample is None else self.df.sample(sample)

# Cell
# Hardcode a replacement for shap's default colormaps.
# Build a 256-entry linear colormap running from yellow ('#ffff00') at 0.0 to
# dark blue ('#002266') at 1.0.
# NOTE(review): despite the variable name `green_blue`, the low end is yellow,
# not green.
green_blue = LinearSegmentedColormap.from_list('custom blue', [(0, '#ffff00'), (1, '#002266')], N=256)
# Overwrite the `red_blue` and `red_blue_solid` attributes on `cl` when this
# module is imported. `cl` is presumably shap's colors module imported under an
# alias (TODO confirm). If so, shap plotting code that looks these attributes
# up at call time will render with green_blue instead of shap's defaults.
cl.red_blue = green_blue
cl.red_blue_solid = green_blue

# Cell
class Shapley:
    """
    SHAP value: https://github.com/slundberg/shap
    """
    def __init__(self, explainer, shap_values, df, df_disp, features):
        shap.initjs()
        self.explainer = explainer
        self.shap_values = shap_values
        self.df, self.df_disp, self.features = df, df_disp, features

    @classmethod
    def from_Tree(cls, learner, ds, df_disp = None, sample = 10000, remove_outlier = True):
github slundberg / shap / shap / plots / decision.py View on Github external
# create a symmetric axis around base_value
        a, b = (base_value - xmin), (xmax - base_value)
        if a > b:
            xlim = (base_value - a, base_value + a)
        else:
            xlim = (base_value - b, base_value + b)
        # Adjust xlim to include a little visual margin.
        a = (xlim[1] - xlim[0]) * 0.02
        xlim = (xlim[0] - a, xlim[1] + a)

    # Initialize style arguments
    if alpha is None:
        alpha = 1.0

    if plot_color is None:
        plot_color = colors.red_blue

    __decision_plot_matplotlib(
        base_value,
        cumsum,
        ascending,
        feature_display_count,
        features_display,
        feature_names_display,
        highlight,
        plot_color,
        axis_color,
        y_demarc_color,
        xlim,
        alpha,
        color_bar,
        auto_size_plot,
github slundberg / shap / shap / plots / image.py View on Github external
else:
            x_curr_gray = x_curr

        axes[row,0].imshow(x_curr, cmap=pl.get_cmap('gray'))
        axes[row,0].axis('off')
        if len(shap_values[0][row].shape) == 2:
            abs_vals = np.stack([np.abs(shap_values[i]) for i in range(len(shap_values))], 0).flatten()
        else:
            abs_vals = np.stack([np.abs(shap_values[i].sum(-1)) for i in range(len(shap_values))], 0).flatten()
        max_val = np.nanpercentile(abs_vals, 99.9)
        for i in range(len(shap_values)):
            if labels is not None:
                axes[row,i+1].set_title(labels[row,i], **label_kwargs)
            sv = shap_values[i][row] if len(shap_values[i][row].shape) == 2 else shap_values[i][row].sum(-1)
            axes[row,i+1].imshow(x_curr_gray, cmap=pl.get_cmap('gray'), alpha=0.15, extent=(-1, sv.shape[0], sv.shape[1], -1))
            im = axes[row,i+1].imshow(sv, cmap=colors.red_transparent_blue, vmin=-max_val, vmax=max_val)
            axes[row,i+1].axis('off')
    if hspace == 'auto':
        fig.tight_layout()
    else:
        fig.subplots_adjust(hspace=hspace)
    cb = fig.colorbar(im, ax=np.ravel(axes).tolist(), label="SHAP value", orientation="horizontal", aspect=fig_size[0]/aspect)
    cb.outline.set_visible(False)
    if show:
        pl.show()
github slundberg / shap / shap / plots / partial_dependence.py View on Github external
features_tmp[:,ind0] = xs0[i]
                features_tmp[:,ind1] = xs1[j]
                x0[i, j] = xs0[i]
                x1[i, j] = xs1[j]
                vals[i, j] = model(features_tmp).mean()
                
        fig = pl.figure()
        ax = fig.add_subplot(111, projection='3d')


#         x = y = np.arange(-3.0, 3.0, 0.05)
#         X, Y = np.meshgrid(x, y)
#         zs = np.array(fun(np.ravel(X), np.ravel(Y)))
#         Z = zs.reshape(X.shape)

        ax.plot_surface(x0, x1, vals, cmap=shap.plots.colors.red_blue_transparent)

        ax.set_xlabel(feature_names[ind0], fontsize=13)
        ax.set_ylabel(feature_names[ind1], fontsize=13)
        ax.set_zlabel("E[f(x) | "+ str(feature_names[ind0]) + ", "+ str(feature_names[ind1]) + "]", fontsize=13)

        if show:
            pl.show()
        else:      
            return fig, ax