How to use the xgboost.plot_tree function

To help you get started, we’ve selected a few xgboost examples based on popular ways this function is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

GitHub: dmlc/xgboost — tests/python/test_with_sklearn.py (view on GitHub)
matplotlib.use('Agg')

    from matplotlib.axes import Axes
    from graphviz import Source

    ax = xgb.plot_importance(classifier)
    assert isinstance(ax, Axes)
    assert ax.get_title() == 'Feature importance'
    assert ax.get_xlabel() == 'F score'
    assert ax.get_ylabel() == 'Features'
    assert len(ax.patches) == 4

    g = xgb.to_graphviz(classifier, num_trees=0)
    assert isinstance(g, Source)

    ax = xgb.plot_tree(classifier, num_trees=0)
    assert isinstance(ax, Axes)
GitHub: dmlc/xgboost — tests/python/test_plotting.py (view on GitHub)
ax = xgb.plot_importance(bst2, color=['r', 'r', 'b', 'b'],
                                 title=None, xlabel=None, ylabel=None)
        assert isinstance(ax, Axes)
        assert ax.get_title() == ''
        assert ax.get_xlabel() == ''
        assert ax.get_ylabel() == ''
        assert len(ax.patches) == 4
        assert ax.patches[0].get_facecolor() == (1.0, 0, 0, 1.0)  # red
        assert ax.patches[1].get_facecolor() == (1.0, 0, 0, 1.0)  # red
        assert ax.patches[2].get_facecolor() == (0, 0, 1.0, 1.0)  # blue
        assert ax.patches[3].get_facecolor() == (0, 0, 1.0, 1.0)  # blue

        g = xgb.to_graphviz(bst2, num_trees=0)
        assert isinstance(g, Source)

        ax = xgb.plot_tree(bst2, num_trees=0)
        assert isinstance(ax, Axes)
GitHub: rupskygill/ML-mastery — xgboost_with_python_code/07_plot_tree-left-to-right.py (view on GitHub)
# plot decision tree
from numpy import loadtxt
from xgboost import XGBClassifier
from xgboost import plot_tree
from matplotlib import pyplot
# Read the Pima Indians diabetes CSV into a numeric array.
dataset = loadtxt('pima-indians-diabetes.csv', delimiter=",")
# Columns 0-7 are the model inputs; column 8 is the label.
X, y = dataset[:, :8], dataset[:, 8]
# Fit a default-configured XGBoost classifier on the whole dataset
# (no train/test split -- this script only exists to draw a tree).
model = XGBClassifier()
model.fit(X, y)
# Draw the first boosted tree (index 0) with a left-to-right layout.
plot_tree(model, num_trees=0, rankdir='LR')
pyplot.show()
GitHub: Ashton-Sidhu/aethos — aethos/modelling/model_analysis.py (view on GitHub)
graph = Source(
                sklearn.tree.export_graphviz(
                    self.model,
                    out_file=None,
                    feature_names=self.features,
                    class_names=classes,
                    rounded=True,
                    precision=True,
                    filled=True,
                )
            )

            display(SVG(graph.pipe(format="svg")))

        elif isinstance(self.model, xgb.XGBModel):
            return xgb.plot_tree(self.model)

        elif isinstance(self.model, lgb.sklearn.LGBMModel):
            return lgb.plot_tree(self.model)

        elif isinstance(self.model, cb.CatBoost):
            return self.model.plot_tree(tree_idx=tree_num, pool=self.pool)

        elif isinstance(self.model, sklearn.ensemble.BaseEnsemble):
            estimator = self.model.estimators_[tree_num]

            graph = Source(
                sklearn.tree.export_graphviz(
                    estimator,
                    out_file=None,
                    feature_names=self.features,
                    class_names=classes,
GitHub: Ashton-Sidhu/aethos — aethos/modelling/model_analysis.py (view on GitHub)
graph = Source(
                sklearn.tree.export_graphviz(
                    self.model,
                    out_file=None,
                    feature_names=self.features,
                    class_names=classes,
                    rounded=True,
                    precision=True,
                    filled=True,
                )
            )

            display(SVG(graph.pipe(format="svg")))

        elif isinstance(self.model, xgb.XGBModel):
            return xgb.plot_tree(self.model)

        elif isinstance(self.model, lgb.sklearn.LGBMModel):
            return lgb.plot_tree(self.model)

        elif isinstance(self.model, cb.CatBoost):
            return self.model.plot_tree(tree_idx=tree_num, pool=self.pool)

        elif isinstance(self.model, sklearn.ensemble.BaseEnsemble):
            estimator = self.model.estimators_[tree_num]

            graph = Source(
                sklearn.tree.export_graphviz(
                    estimator,
                    out_file=None,
                    feature_names=self.features,
                    class_names=classes,
GitHub: google-research/google-research — sparse_data/exp_framework/gbdt.py (view on GitHub)
def plot_tree(model, directory, num_tree=10):
  """Creates and saves a plot of the trees in a gradient boosted tree model.

  One PNG per tree is written to '<FILES_PATH>/<directory>/tree-<idx>.png'
  for idx in 0 .. num_tree - 1.

  Args:
    model: xgb.Booster trained XGBoost gradient boosted trees model
    directory: string directory of save location, relative to FILES_PATH;
      created if missing and reused if it already exists
    num_tree: number of trees to plot (trees 0 .. num_tree - 1)
  """
  base_path = '{}/{}'.format(FILES_PATH, directory)
  # exist_ok allows re-running into the same directory instead of failing
  # with FileExistsError on the second call.
  os.makedirs(base_path, exist_ok=True)

  try:
    for tree_idx in range(num_tree):
      # Save the figure that owns the Axes xgb.plot_tree drew on, rather than
      # whatever plt.gcf() happens to consider current.
      ax = xgb.plot_tree(model, num_trees=tree_idx)
      fig = ax.figure
      fig.set_size_inches(120, 120)

      path = '{}/tree-{}.png'.format(base_path, tree_idx)
      fig.savefig(path)
      # These figures are very large; release each one as soon as it is saved
      # instead of keeping all of them open until the loop finishes.
      plt.close(fig)
  finally:
    # Matches the original final cleanup, and also runs if plotting or saving
    # raised part-way through the loop.
    plt.close('all')

  logging.info('Saved plots to: %s \n', base_path)
GitHub: colinmorris/instacart-basket-prediction — nonrecurrent/plot_model.py (view on GitHub)
def plot_trees(n=2, ax=None, rankdir='TB'):
  """Draws boosted tree number `n` of the module-level `model`.

  Args:
    n: index of the tree to draw; forwarded to xgb.plot_tree as `num_trees`.
    ax: optional matplotlib Axes to draw into; forwarded to xgb.plot_tree.
    rankdir: Graphviz layout direction, e.g. 'TB' (top-to-bottom, default)
      or 'LR' (left-to-right); 'BT' and 'RL' are also valid Graphviz values.
  """
  # NOTE(review): `model` is a global defined elsewhere in this module;
  # presumably a trained XGBoost model -- confirm.
  # The previous default 'UT' is not a Graphviz rankdir value (Graphviz
  # falls back to top-to-bottom for it), so 'TB' keeps the same layout.
  xgb.plot_tree(model, num_trees=n, rankdir=rankdir, ax=ax)