Training a decision tree against unbalanced data

Last Updated : 21 Apr, 2025

When building machine learning models such as decision trees, imbalanced data, where one class significantly outnumbers another, presents a particular challenge. For example, in credit card transactions, legitimate transactions may vastly outnumber fraudulent ones. Without adjustments, decision trees tend to favor the majority class, often misclassifying the minority class.

Decision trees are effective classifiers that construct logical rules to separate classes. On imbalanced datasets, however, splits are chosen to maximize node purity, and a node dominated by the majority class already looks nearly pure. The tree therefore has little incentive to isolate the few minority samples, which can lead to high misclassification rates for the minority class.
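To make the purity argument concrete, here is a minimal sketch that computes Gini impurity for a single node. The `gini` helper and the 995/5 counts are illustrative assumptions, not part of the original example; the weighting uses the common "inverse class frequency" heuristic.

```python
import numpy as np

def gini(class_counts, class_weights=None):
    """Gini impurity of a node from per-class sample counts.

    If class_weights is given, each count is scaled by its class weight
    before the class proportions are computed.
    """
    counts = np.asarray(class_counts, dtype=float)
    if class_weights is not None:
        counts = counts * np.asarray(class_weights, dtype=float)
    proportions = counts / counts.sum()
    return 1.0 - np.sum(proportions ** 2)

# Hypothetical node: 995 majority samples and 5 minority samples
counts = [995, 5]

unweighted = gini(counts)

# Weights inversely proportional to class frequency:
# n_samples / (n_classes * class_count)
weights = [1000 / (2 * 995), 1000 / (2 * 5)]
weighted = gini(counts, weights)

print(f"Unweighted Gini: {unweighted:.5f}")  # 0.00995
print(f"Weighted Gini:   {weighted:.5f}")    # 0.50000
```

Without weights the node already looks almost pure, so no split can reduce impurity by much. With inverse-frequency weights the same node looks maximally mixed, which gives the tree a reason to separate the minority samples.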

Strategies to Address Imbalance

  1. Cost-Sensitive Learning: Assign higher misclassification costs to the minority class (e.g., fraud), encouraging the model to consider it more seriously.
  2. Alternative Splitting Criteria: Use metrics like Hellinger distance instead of traditional ones like information gain, since Hellinger distance compares how each class is distributed across the branches and is therefore less sensitive to class skew.
  3. Sampling Techniques: Balance the dataset by oversampling the minority class or undersampling the majority class, or by using wrapper frameworks that combine sampling with the splitting metric.
  4. Adjusted Evaluation Metrics: Accuracy alone is misleading in imbalanced settings. Instead, prioritize metrics like precision, recall, and F1-score to assess the model’s performance on the minority class more accurately.
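Two of the strategies above, sampling and an alternative splitting criterion, can be sketched in a few lines. Both helpers below (`random_oversample` and `hellinger_split_score`) are hypothetical names written for illustration. They assume a binary problem, and the Hellinger score follows the two-class form commonly used for Hellinger-distance trees. scikit-learn's `DecisionTreeClassifier` does not offer this criterion, so the function only scores a candidate split.

```python
import numpy as np
from sklearn.utils import resample

def random_oversample(X, y, minority_label, random_state=0):
    """Duplicate minority rows (sampling with replacement) until both
    classes have the same number of samples. Binary problems only."""
    X, y = np.asarray(X), np.asarray(y)
    minority = y == minority_label
    n_majority = int((~minority).sum())
    X_min_up, y_min_up = resample(
        X[minority], y[minority],
        replace=True, n_samples=n_majority, random_state=random_state,
    )
    X_bal = np.vstack([X[~minority], X_min_up])
    y_bal = np.concatenate([y[~minority], y_min_up])
    return X_bal, y_bal

def hellinger_split_score(y, goes_left, positive_label=1):
    """Hellinger distance between the positive-class and negative-class
    distributions over the two branches of a binary split."""
    y = np.asarray(y)
    goes_left = np.asarray(goes_left, dtype=bool)
    pos = y == positive_label
    neg = ~pos
    score = 0.0
    for branch in (goes_left, ~goes_left):
        pos_frac = (pos & branch).sum() / pos.sum()
        neg_frac = (neg & branch).sum() / neg.sum()
        score += (np.sqrt(pos_frac) - np.sqrt(neg_frac)) ** 2
    return float(np.sqrt(score))

# Toy data: 95 negatives followed by 5 positives, one feature
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 1))
y = np.array([0] * 95 + [1] * 5)

X_bal, y_bal = random_oversample(X, y, minority_label=1)
print("Before:", np.bincount(y), "After:", np.bincount(y_bal))

# A split that isolates every positive sample in the left branch
perfect_split = y == 1

# A split that sends the same share (20%) of each class to the left branch
proportional_split = np.zeros(100, dtype=bool)
proportional_split[:19] = True   # 19 of 95 negatives
proportional_split[95] = True    # 1 of 5 positives

print("Perfect split score:     ", hellinger_split_score(y, perfect_split))
print("Proportional split score:", hellinger_split_score(y, proportional_split))
```

The score ranges from 0 (the split tells the classes apart no better than chance) to the square root of 2 (the split separates them completely), and it depends only on within-class fractions, not on how many samples each class has. Oversampling should be applied to the training split only, so that the test set keeps the original class distribution.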

Example in Practice

Consider a decision tree classifying medical records as healthy or diseased. With 99.5% healthy cases, the tree may default to "healthy" predictions, yielding high accuracy but misclassifying all diseased cases. By applying cost-sensitive learning or alternative splitting criteria, the tree can prioritize correct classification of the minority diseased cases.
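The arithmetic behind this scenario is easy to check. The sketch below assumes a hypothetical cohort of 1,000 records (995 healthy, 5 diseased) and a degenerate model that always predicts "healthy".

```python
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, recall_score

# Hypothetical cohort: 995 healthy (label 0) and 5 diseased (label 1)
y_true = np.array([0] * 995 + [1] * 5)

# A degenerate model that always predicts "healthy"
y_pred = np.zeros_like(y_true)

accuracy = accuracy_score(y_true, y_pred)
recall = recall_score(y_true, y_pred, pos_label=1)
# zero_division=0 avoids a warning: precision is undefined when the
# model never predicts the positive class
f1 = f1_score(y_true, y_pred, pos_label=1, zero_division=0)

print(f"Accuracy:        {accuracy:.3f}")  # 0.995
print(f"Diseased recall: {recall:.3f}")    # 0.000
print(f"Diseased F1:     {f1:.3f}")        # 0.000
```

An accuracy of 99.5% hides the fact that not a single diseased case is detected, which is why recall and F1 on the minority class are the numbers to watch.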

Let's now walk through a code example on a synthetic imbalanced dataset:

Creating An Unbalanced Dataset

Python
# Import necessary libraries
from collections import Counter

from sklearn.datasets import make_classification
from sklearn.metrics import classification_report, roc_auc_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Create an imbalanced dataset
X, y = make_classification(
    n_classes=3,
    class_sep=2,
    weights=[0.7, 0.2, 0.1],
    n_informative=5,
    n_redundant=0,
    flip_y=0,
    n_features=10,
    n_clusters_per_class=1,
    n_samples=1000,
    random_state=42
)

# Display class distribution
print("Class distribution:", Counter(y))

# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)


This code generates an imbalanced synthetic dataset with three classes using make_classification. The weights parameter controls the class distribution (70%, 20%, and 10% for classes 0, 1, and 2). The code then prints the class distribution and splits the data into training (70%) and testing (30%) sets for the classifier.
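The split above is purely random, so the class ratios in the training and test sets can drift from the 70/20/10 target. One optional variation, not used in the original example, is to pass `stratify=y` to `train_test_split`. The `class_shares` helper below is a hypothetical name introduced only for this sketch.

```python
from collections import Counter

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

# Same synthetic dataset as in the walkthrough
X, y = make_classification(
    n_classes=3, class_sep=2, weights=[0.7, 0.2, 0.1], n_informative=5,
    n_redundant=0, flip_y=0, n_features=10, n_clusters_per_class=1,
    n_samples=1000, random_state=42,
)

# stratify=y keeps the class proportions (approximately) equal in both splits
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, stratify=y, random_state=42
)

def class_shares(labels):
    """Return the fraction of samples belonging to each class."""
    counts = Counter(int(label) for label in labels)
    total = sum(counts.values())
    return {cls: counts[cls] / total for cls in sorted(counts)}

overall_shares = class_shares(y)
train_shares = class_shares(y_train)
test_shares = class_shares(y_test)

print("Overall:", overall_shares)
print("Train:  ", train_shares)
print("Test:   ", test_shares)
```

Stratification matters most when the minority class is tiny, because a random split could leave the test set with too few minority samples to evaluate. Switching to a stratified split changes which rows land in the test set, so the per-class support values would differ from those in the article's output.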

Output:

Class distribution: Counter({np.int64(0): 700, np.int64(1): 200, np.int64(2): 100})

Training the Decision Tree with Adjustments

Python
# Initialize the Decision Tree Classifier with custom weights and controlled depth
dt_classifier = DecisionTreeClassifier(
    criterion='gini', 
    max_depth=4, 
    min_samples_leaf=5, 
    class_weight='balanced', 
    random_state=42
)

# Train the model
dt_classifier.fit(X_train, y_train)

# Make predictions
y_pred = dt_classifier.predict(X_test)

# Evaluate the model with precision, recall, and F1-score
print("\nClassification Report:\n", classification_report(y_test, y_pred))

# Calculate ROC AUC Score (requires binary or one-vs-all approach for multi-class)
roc_auc = roc_auc_score(
    y_test, 
    dt_classifier.predict_proba(X_test), 
    multi_class='ovr'
)
print(f"ROC AUC Score: {roc_auc:.2f}")


This code initializes a decision tree classifier with balanced class weights to address the imbalance, and limits the tree's depth (max_depth=4) and minimum leaf size (min_samples_leaf=5) to reduce overfitting. After training, it reports precision, recall, and F1-score for each class and calculates the ROC AUC score using a one-vs-rest approach for the multi-class setting.
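To see what class_weight='balanced' does, the sketch below reproduces the weights scikit-learn documents for this mode, n_samples / (n_classes * np.bincount(y)). For illustration it uses the 700/200/100 counts of the full dataset. In the walkthrough the weights are derived from the training labels, so the exact values there differ slightly.

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# Illustrative labels with the full dataset's class counts
y = np.repeat([0, 1, 2], [700, 200, 100])
classes = np.unique(y)

# Manual formula: n_samples / (n_classes * count of each class)
manual_weights = len(y) / (len(classes) * np.bincount(y))

# The same computation via scikit-learn's helper
sklearn_weights = compute_class_weight(class_weight="balanced", classes=classes, y=y)

for cls, manual, from_sklearn in zip(classes, manual_weights, sklearn_weights):
    print(f"class {cls}: manual={manual:.3f}, sklearn={from_sklearn:.3f}")
# class 0: manual=0.476, sklearn=0.476
# class 1: manual=1.667, sklearn=1.667
# class 2: manual=3.333, sklearn=3.333
```

Each sample of the rarest class counts about seven times as much as a sample of the most common class, so every class contributes the same total weight during training.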

Output:

Classification Report:
               precision    recall  f1-score   support

           0       1.00      1.00      1.00       217
           1       1.00      0.98      0.99        56
           2       0.96      1.00      0.98        27

    accuracy                           1.00       300
   macro avg       0.99      0.99      0.99       300
weighted avg       1.00      1.00      1.00       300

ROC AUC Score: 1.00
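A confusion matrix and a view of the learned rules complement the report above. The sketch below is a self-contained, text-only variant: it rebuilds the same data and model, then uses confusion_matrix and export_text. The graphical counterparts in scikit-learn are ConfusionMatrixDisplay.from_predictions and plot_tree, which require matplotlib.

```python
from sklearn.datasets import make_classification
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_text

# Same synthetic data, split, and tree settings as in the walkthrough
X, y = make_classification(
    n_classes=3, class_sep=2, weights=[0.7, 0.2, 0.1], n_informative=5,
    n_redundant=0, flip_y=0, n_features=10, n_clusters_per_class=1,
    n_samples=1000, random_state=42,
)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42
)

dt_classifier = DecisionTreeClassifier(
    criterion="gini", max_depth=4, min_samples_leaf=5,
    class_weight="balanced", random_state=42,
)
dt_classifier.fit(X_train, y_train)
y_pred = dt_classifier.predict(X_test)

# Rows are true classes, columns are predicted classes
cm = confusion_matrix(y_test, y_pred)
print("Confusion matrix:\n", cm)

# Plain-text rendering of the split rules learned by the tree
tree_rules = export_text(dt_classifier)
print(tree_rules)
```

Off-diagonal entries in the minority rows show how many minority samples were missed and which class absorbed them.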

The benefit of using balanced class weights can be observed in the following decision boundaries:

[Figure: Difference between decision boundaries using default and balanced class weights]
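A numeric counterpart to this comparison can be sketched by training two trees that differ only in class_weight and comparing per-class recall. The dataset below is an assumption made for illustration: a harder binary problem (95% / 5%, overlapping classes) than the one in the walkthrough, where the classes are so well separated that both settings score almost perfectly. No specific result is implied; the outcome depends on the data.

```python
from sklearn.datasets import make_classification
from sklearn.metrics import recall_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Hypothetical, deliberately harder binary problem with overlapping classes
X, y = make_classification(
    n_samples=2000, n_features=10, n_informative=5, n_redundant=0,
    weights=[0.95, 0.05], class_sep=0.8, flip_y=0.01, random_state=0,
)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, stratify=y, random_state=0
)

recalls = {}
for name, class_weight in [("default", None), ("balanced", "balanced")]:
    tree = DecisionTreeClassifier(
        max_depth=4, min_samples_leaf=5,
        class_weight=class_weight, random_state=0,
    )
    tree.fit(X_train, y_train)
    # average=None returns one recall value per class, ordered by class label
    recalls[name] = recall_score(y_test, tree.predict(X_test), average=None)
    print(f"{name:>8}: majority recall={recalls[name][0]:.2f}, "
          f"minority recall={recalls[name][1]:.2f}")
```

Balanced weights typically trade some majority-class recall for better minority-class recall. Whether that trade is worthwhile depends on the relative cost of the two kinds of error.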


Key Takeaways

  • Recognize Imbalance Effects: Imbalanced data skews decision trees toward the majority class.
  • Apply Balancing Strategies: Cost-sensitive learning, alternative splitting criteria, and sampling improve minority class predictions.
  • Use Appropriate Metrics: Precision, recall, and F1-score provide a clearer picture of model performance on the minority class.
