How to visualise an XGBoost tree in Python?

This recipe helps you visualise an XGBoost tree in Python.

Recipe Objective

Have you ever wanted to plot an XGBoost model in Python and view it as a tree diagram? In this recipe we train an XGBoost classifier, predict the output for a test set, and plot individual trees from the fitted model.

So this is the recipe on how we can visualise an XGBoost tree in Python.

Step 1 - Import the library

from sklearn import datasets
from sklearn import metrics
from xgboost import XGBClassifier, plot_tree
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
plt.style.use('ggplot')

We have imported all the modules we need, such as metrics, datasets, XGBClassifier and plot_tree. We will see the use of each module step by step below.

Step 2 - Setting up the Data for Classifier

We load the built-in breast cancer dataset from the datasets module and store the features in X and the target in y. We then use train_test_split to split the dataset so that 30% of the data goes to the test set and the rest to the training set.

dataset = datasets.load_breast_cancer()
X = dataset.data
y = dataset.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.30)
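Because train_test_split shuffles the data randomly, each run of the split above gives different rows in the test set and slightly different scores. A minimal sketch of a reproducible variant follows; the random_state value of 42 and the use of stratify are assumptions added for illustration and are not part of the original recipe.

```python
from sklearn import datasets
from sklearn.model_selection import train_test_split

dataset = datasets.load_breast_cancer()
X = dataset.data
y = dataset.target

# random_state=42 is an arbitrary seed (assumption); stratify=y keeps the
# malignant/benign ratio similar in the train and test sets.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.30, random_state=42, stratify=y
)

print(X_train.shape, X_test.shape)
```

The dataset has 569 samples, so a 30% test split holds out 171 rows, which is the total support shown in the classification report of this recipe.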

Step 3 - Training XGBClassifier and Predicting the output

We create a model object and fit it on the training data. We then test the model by predicting the target for the test data.

model_XGB = XGBClassifier()
model_XGB.fit(X_train, y_train)
print(model_XGB)
expected_y = y_test
predicted_y = model_XGB.predict(X_test)

Step 4 - Calculating the Scores

Now we calculate the scores for the model using classification_report and confusion_matrix, passing the expected and predicted target values for the test set.

print(metrics.classification_report(expected_y, predicted_y, target_names=dataset.target_names))
print(metrics.confusion_matrix(expected_y, predicted_y))
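The numbers in the classification report can be derived by hand from the confusion matrix. As a rough cross-check, the sketch below takes the confusion matrix from the sample output at the end of this recipe and recomputes precision, recall and F1 in plain Python; the variable names are illustrative only.

```python
# Confusion matrix from the sample output of this recipe:
# rows are true classes, columns are predicted classes,
# in the order (malignant, benign).
cm = [[66, 4],
      [1, 100]]
names = ["malignant", "benign"]

total = sum(sum(row) for row in cm)
accuracy = (cm[0][0] + cm[1][1]) / total

scores = {}
for i, name in enumerate(names):
    true_pos = cm[i][i]
    predicted_as_i = cm[0][i] + cm[1][i]   # column sum
    actually_i = sum(cm[i])                # row sum (the "support")
    precision = true_pos / predicted_as_i
    recall = true_pos / actually_i
    f1 = 2 * precision * recall / (precision + recall)
    scores[name] = (round(precision, 2), round(recall, 2), round(f1, 2), actually_i)
    print(name, scores[name])

print("accuracy", round(accuracy, 2))
```

The rounded values agree with the malignant and benign rows of the report: precision is the diagonal entry divided by its column sum, and recall is the diagonal entry divided by its row sum.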

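Step 5 - Plotting the tree

We plot trees from the fitted XGBClassifier by passing the required parameters to plot_tree. The first call plots the first tree (index 0), the second plots the tree at index 4, and the third plots the first tree again with a left-to-right layout. Note that plot_tree needs the graphviz package to be installed.

plot_tree(model_XGB)
plt.show()
plot_tree(model_XGB, num_trees=4)
plt.show()
plot_tree(model_XGB, num_trees=0, rankdir='LR')
plt.show()

So the final output comes as: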

XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1,
       colsample_bynode=1, colsample_bytree=1, gamma=0, learning_rate=0.1,
       max_delta_step=0, max_depth=3, min_child_weight=1, missing=None,
       n_estimators=100, n_jobs=1, nthread=None,
       objective='binary:logistic', random_state=0, reg_alpha=0,
       reg_lambda=1, scale_pos_weight=1, seed=None, silent=None,
       subsample=1, verbosity=1)

XGBClassifier: 

              precision    recall  f1-score   support

   malignant       0.99      0.94      0.96        70
      benign       0.96      0.99      0.98       101

   micro avg       0.97      0.97      0.97       171
   macro avg       0.97      0.97      0.97       171
weighted avg       0.97      0.97      0.97       171


[[ 66   4]
 [  1 100]]
