import pydotplus
from IPython.display import Image
from sklearn.tree import export_graphviz
import os
# Make the bundled Graphviz binaries discoverable so pydotplus can launch `dot`.
# The entry must be joined with os.pathsep; plain concatenation would fuse it
# onto the last existing PATH entry and the directory would never be searched.
os.environ['PATH'] = os.environ.get('PATH', '') + os.pathsep + r'.\Graphviz2.38\bin'
def print_graph(dtr, feature_names):
    """Render a fitted decision tree as a PNG image for inline display.

    The tree is exported to Graphviz DOT source with ``export_graphviz``,
    parsed by ``pydotplus`` and rasterized to PNG. Rasterizing requires the
    Graphviz executables to be reachable through ``PATH``.

    Args:
        dtr: Fitted decision-tree estimator accepted by ``export_graphviz``.
        feature_names: Iterable of feature names in column order (for
            example ``X.columns``), or ``None`` to use generic names.

    Returns:
        An ``IPython.display.Image`` wrapping the rendered PNG bytes.
    """
    # Normalize to a plain list so a pandas Index (or any other iterable) is
    # handled the same way regardless of the installed scikit-learn version.
    if feature_names is not None:
        feature_names = list(feature_names)

    dot_data = export_graphviz(
        dtr,
        feature_names=feature_names,
        # export_graphviz documents class_names as a sequence in ascending
        # class order (index 0 -> "D", index 1 -> "R"), not a mapping.
        class_names=["D", "R"],
        label="root",
        proportion=True,
        impurity=False,
        out_file=None,  # return the DOT source as a string instead of writing a file
        filled=True,
        rounded=True,
    )
    graph = pydotplus.graph_from_dot_data(dot_data)
    return Image(graph.create_png())
from sklearn.tree import DecisionTreeClassifier
# Fit a depth-3 decision tree on (X, y); `fit` returns the estimator itself,
# so construction and training can be chained into one expression.
dtr = DecisionTreeClassifier(
    max_depth=3,
    random_state=SEED,
).fit(X, y)

# Render the fitted tree, labelling split nodes with the columns of X.
print_graph(dtr, X.columns)