【第二章:机器学习与神经网络概述】04.回归算法理论与实践 -(3)决策树回归模型(Decision Tree Regression)

第二章: 机器学习与神经网络概述

第四部分:回归算法理论与实践

第三节:决策树回归模型

内容:剪枝方法、回归树结构与算法实现。

决策树回归模型是一种非参数的监督学习方法,通过将特征空间划分为多个区域,在每个区域内做常数预测,适合处理非线性回归问题、特征交互明显的数据集。


一、基本原理

决策树回归以CART(Classification and Regression Trees)算法为基础,通过不断划分特征空间,构建一棵回归树:

  • 每个内部节点表示对某一特征的判断;

  • 每个叶节点表示一个预测值(区域内样本均值);

  • 划分依据:最小化划分后区域内的均方误差(MSE)。


二、划分准则与误差计算

对样本集 D,假设以特征 x_j 的值 s 作为划分点,将样本划分为:

  • D_{\text{left}} = \{x \in D \mid x_j \leq s\}

  • D_{\text{right}} = \{x \in D \mid x_j > s\}

其目标是最小化总的平方误差:

\min_{j,s} \left[ \sum_{x_i \in D_{\text{left}}} (y_i - \bar{y}_{\text{left}})^2 + \sum_{x_i \in D_{\text{right}}} (y_i - \bar{y}_{\text{right}})^2 \right]


三、剪枝策略(Pruning)

决策树容易过拟合,需通过剪枝来控制复杂度:

1. 预剪枝(Pre-Pruning)
  • 在构建过程中提前停止划分:

    • 达到最大深度 max_depth

    • 每个节点最小样本数 min_samples_split

    • MSE 减少小于阈值

2. 后剪枝(Post-Pruning)
  • 先生成整棵树,再从底向上剪去“收益小”的分支(如 sklearn 的 ccp_alpha 参数)

  • 剪枝目标:在保留预测能力的前提下降低模型复杂度


四、Python 实现示例(使用 sklearn)

from sklearn.tree import DecisionTreeRegressor
import numpy as np
import matplotlib.pyplot as plt

# 构造数据
rng = np.random.RandomState(1)
X = np.sort(5 * rng.rand(80, 1), axis=0)
y = np.sin(X).ravel() + np.random.normal(0, 0.1, X.shape[0])

# 模型:不剪枝 vs 预剪枝 vs 后剪枝
reg_full = DecisionTreeRegressor()
reg_pruned = DecisionTreeRegressor(max_depth=3)
reg_ccp = DecisionTreeRegressor(ccp_alpha=0.01)

# 训练
reg_full.fit(X, y)
reg_pruned.fit(X, y)
reg_ccp.fit(X, y)

plt.rcParams['font.sans-serif'] = ['SimHei']
# 解决负号'-'显示为方块的问题
plt.rcParams['axes.unicode_minus'] = False

# 可视化
X_test = np.arange(0.0, 5.0, 0.01)[:, np.newaxis]
plt.figure(figsize=(10, 6))
plt.scatter(X, y, s=20, label="data", color="black")
plt.plot(X_test, reg_full.predict(X_test), label="Full Tree", linewidth=2)
plt.plot(X_test, reg_pruned.predict(X_test), label="Pre-Pruned (depth=3)", linestyle="--")
plt.plot(X_test, reg_ccp.predict(X_test), label="Post-Pruned (ccp_alpha=0.01)", linestyle=":")
plt.legend()
plt.title("回归树剪枝效果对比")
plt.xlabel("X")
plt.ylabel("y")
plt.grid(True)
plt.tight_layout()
plt.show()

 


五、优缺点分析

优点缺点
逻辑简单、易理解容易过拟合,需要剪枝
可处理非线性和多维特征交互对微小变化敏感,稳定性差
不需标准化或归一化对样本数量和分布较敏感
可解释性强(树结构明确)難以推广:小数据表现好,大数据可能需集成优化

六、模型调参建议

参数作用建议
max_depth限制树的最大深度控制模型复杂度,避免过拟合
min_samples_split拆分内部节点所需最小样本数增大可减少模型复杂度
min_samples_leaf每个叶子节点的最小样本数增大有助于平滑预测结果
ccp_alpha后剪枝惩罚系数(复杂度代价剪枝)自动调节树结构,可结合验证集选择最佳值

七、典型应用场景

  • 房价预测(特征离散明显)

  • 电商销售量预测

  • 时间序列短期预测(可结合滑窗技术)

  • 特征交互复杂但不多的中小数据集建模


补充:

可视化决策树结构
from sklearn.tree import plot_tree
from sklearn.tree import DecisionTreeRegressor
import numpy as np
import matplotlib.pyplot as plt

# 构造数据
rng = np.random.RandomState(1)
X = np.sort(5 * rng.rand(80, 1), axis=0)
y = np.sin(X).ravel() + np.random.normal(0, 0.1, X.shape[0])

# 模型:不剪枝 vs 预剪枝 vs 后剪枝
reg_pruned = DecisionTreeRegressor(max_depth=3)

# 训练
reg_pruned.fit(X, y)

plt.rcParams['font.sans-serif'] = ['SimHei']
# 解决负号'-'显示为方块的问题
plt.rcParams['axes.unicode_minus'] = False


plt.figure(figsize=(12, 6))
plot_tree(reg_pruned, filled=True, feature_names=["X"], rounded=True)
plt.title("回归树结构(max_depth=3)")
plt.show()

 


回归树结构图(plot_tree)
from sklearn.tree import DecisionTreeRegressor, plot_tree
import matplotlib.pyplot as plt
import numpy as np

# 构造样本数据
X = np.array([[1], [2], [3], [4], [5], [6], [7], [8]])
y = np.array([5, 4.5, 4, 3.5, 3, 2.5, 2, 1.5])

# 创建并训练模型
tree = DecisionTreeRegressor(max_depth=3, random_state=42)
tree.fit(X, y)

# 可视化决策树结构
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.figure(figsize=(10, 6))
plot_tree(tree, feature_names=["X"], filled=True, rounded=True)
plt.title("回归树结构图 (max_depth=3)")
plt.show()

 


剪枝前后预测曲线对比图
from sklearn.tree import DecisionTreeRegressor
import numpy as np
import matplotlib.pyplot as plt

# 构造数据
rng = np.random.RandomState(0)
X = np.sort(5 * rng.rand(80, 1), axis=0)
y = np.sin(X).ravel() + rng.normal(0, 0.1, X.shape[0])

# 不剪枝模型
reg_full = DecisionTreeRegressor()
reg_full.fit(X, y)

# 预剪枝模型(限制最大深度)
reg_pruned = DecisionTreeRegressor(max_depth=3)
reg_pruned.fit(X, y)

# 后剪枝模型(设置复杂度惩罚参数)
reg_ccp = DecisionTreeRegressor(ccp_alpha=0.01)
reg_ccp.fit(X, y)

# 测试数据
X_test = np.linspace(0, 5, 500).reshape(-1, 1)

# 可视化
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.figure(figsize=(10, 6))
plt.scatter(X, y, label="Train Data", color="black", s=20)
plt.plot(X_test, reg_full.predict(X_test), label="Full Tree", color="blue")
plt.plot(X_test, reg_pruned.predict(X_test), label="Pre-Pruned (max_depth=3)", color="green", linestyle="--")
plt.plot(X_test, reg_ccp.predict(X_test), label="Post-Pruned (ccp_alpha=0.01)", color="red", linestyle=":")
plt.title("回归树剪枝对比图")
plt.xlabel("X")
plt.ylabel("y")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值