Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# Select the non-interactive Agg backend before any plotting happens.
matplotlib.use('Agg')
from matplotlib.axes import Axes
from graphviz import Source

# --- classifier: default importance-plot labels ------------------------------
ax = xgb.plot_importance(classifier)
assert isinstance(ax, Axes)
assert (ax.get_title(), ax.get_xlabel(), ax.get_ylabel()) == (
    'Feature importance', 'F score', 'Features')
assert len(ax.patches) == 4

# Tree 0 can be exported both as a graphviz Source and as a matplotlib Axes.
g = xgb.to_graphviz(classifier, num_trees=0)
assert isinstance(g, Source)
ax = xgb.plot_tree(classifier, num_trees=0)
assert isinstance(ax, Axes)

# --- bst2: explicit per-bar colors, all labels suppressed --------------------
ax = xgb.plot_importance(
    bst2,
    color=['r', 'r', 'b', 'b'],
    title=None,
    xlabel=None,
    ylabel=None,
)
assert isinstance(ax, Axes)
assert (ax.get_title(), ax.get_xlabel(), ax.get_ylabel()) == ('', '', '')
assert len(ax.patches) == 4
# Bar i is drawn with color i from the list passed above.
assert [bar.get_facecolor() for bar in ax.patches] == [
    (1.0, 0, 0, 1.0),  # red
    (1.0, 0, 0, 1.0),  # red
    (0, 0, 1.0, 1.0),  # blue
    (0, 0, 1.0, 1.0),  # blue
]

g = xgb.to_graphviz(bst2, num_trees=0)
assert isinstance(g, Source)
ax = xgb.plot_tree(bst2, num_trees=0)
assert isinstance(ax, Axes)
# plot decision tree
from numpy import loadtxt
from xgboost import XGBClassifier, plot_tree
from matplotlib import pyplot

# Load the CSV; columns 0-7 are the features, column 8 is the target label.
dataset = loadtxt('pima-indians-diabetes.csv', delimiter=",")
X, y = dataset[:, :8], dataset[:, 8]

# Fit a default-configured XGBoost classifier on the training data.
model = XGBClassifier()
model.fit(X, y)

# Draw the first boosted tree (index 0) laid out left-to-right, then show it.
plot_tree(model, num_trees=0, rankdir='LR')
pyplot.show()
graph = Source(
sklearn.tree.export_graphviz(
self.model,
out_file=None,
feature_names=self.features,
class_names=classes,
rounded=True,
precision=True,
filled=True,
)
)
display(SVG(graph.pipe(format="svg")))
elif isinstance(self.model, xgb.XGBModel):
return xgb.plot_tree(self.model)
elif isinstance(self.model, lgb.sklearn.LGBMModel):
return lgb.plot_tree(self.model)
elif isinstance(self.model, cb.CatBoost):
return self.model.plot_tree(tree_idx=tree_num, pool=self.pool)
elif isinstance(self.model, sklearn.ensemble.BaseEnsemble):
estimator = self.model.estimators_[tree_num]
graph = Source(
sklearn.tree.export_graphviz(
estimator,
out_file=None,
feature_names=self.features,
class_names=classes,
graph = Source(
sklearn.tree.export_graphviz(
self.model,
out_file=None,
feature_names=self.features,
class_names=classes,
rounded=True,
precision=True,
filled=True,
)
)
display(SVG(graph.pipe(format="svg")))
elif isinstance(self.model, xgb.XGBModel):
return xgb.plot_tree(self.model)
elif isinstance(self.model, lgb.sklearn.LGBMModel):
return lgb.plot_tree(self.model)
elif isinstance(self.model, cb.CatBoost):
return self.model.plot_tree(tree_idx=tree_num, pool=self.pool)
elif isinstance(self.model, sklearn.ensemble.BaseEnsemble):
estimator = self.model.estimators_[tree_num]
graph = Source(
sklearn.tree.export_graphviz(
estimator,
out_file=None,
feature_names=self.features,
class_names=classes,
def plot_tree(model, directory, num_tree=10):
    """Create and save one PNG plot per tree of a gradient boosted tree model.

    Images are written to ``'{FILES_PATH}/{directory}/tree-{idx}.png'`` for
    ``idx`` in ``0 .. num_tree - 1``.

    Args:
        model: xgb.Booster trained XGBoost gradient boosted trees model.
        directory: string directory of save location, joined under FILES_PATH.
        num_tree: number of trees to plot (trees 0 .. num_tree - 1).
    """
    base_path = '{}/{}'.format(FILES_PATH, directory)
    # exist_ok: re-running an export into the same directory must not crash.
    os.makedirs(base_path, exist_ok=True)
    for tree_idx in range(num_tree):
        ax = xgb.plot_tree(model, num_trees=tree_idx)
        # Use the figure this tree was actually drawn on, not whatever pyplot
        # currently considers the "current" figure.
        fig = ax.figure
        try:
            fig.set_size_inches(120, 120)
            path = '{}/tree-{}.png'.format(base_path, tree_idx)
            fig.savefig(path)
        finally:
            # Close each (very large) figure as soon as it is saved instead of
            # keeping all of them alive until the end. Only this figure is
            # closed, so figures the caller owns are left untouched.
            plt.close(fig)
    logging.info('Saved plots to: %s \n', base_path)
def plot_trees(n=2, ax=None, rankdir='UT', booster=None):  # 'LR' for left-right
    """Plot a single tree of an XGBoost model via ``xgb.plot_tree``.

    Args:
        n: index of the tree to draw (forwarded as ``num_trees``).
        ax: optional matplotlib Axes to draw into (None lets ``xgb.plot_tree``
            decide).
        rankdir: Graphviz layout direction, e.g. 'LR' for left-right. The
            default 'UT' is not one of Graphviz's documented rankdir values
            (TB, LR, BT, RL); it is kept unchanged for backward compatibility.
        booster: model/booster to plot. When None, falls back to the
            module-level ``model`` that this function always used before.

    Returns:
        The Axes returned by ``xgb.plot_tree`` (previously discarded).
    """
    if booster is None:
        # Backward-compatible default: the module-level global ``model``.
        booster = model
    return xgb.plot_tree(booster, num_trees=n, rankdir=rankdir, ax=ax)