一、模块结构概览
import numpy as np
from sklearn.model_selection import cross_validate, learning_curve
from sklearn.pipeline import make_pipeline
from sklearn.svm import SVC
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import make_scorer, accuracy_score, recall_score, f1_score
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.utils import shuffle
import os
依赖说明:
-
numpy
:处理数值计算 -
sklearn
:提供机器学习算法和工具 -
matplotlib
:可视化学习曲线 -
os
:处理文件路径操作
二、核心类定义
2.1 类初始化
class ModelTrainer:
def __init__(self):
pass
功能:创建模型训练器的基础类,当前无需特殊初始化参数
2.2 主训练方法 train_model
2.2.1 数据准备阶段
def train_model(self, X, y, output_dir="model_plots"):
# 创建输出文件夹
os.makedirs(output_dir, exist_ok=True)
# 数据分割
X_train, X_test, y_train, y_test = train_test_split(
X, y,
test_size=0.2, # 20%测试集
stratify=y, # 保持类别分布
random_state=42 # 可重复性种子
)
# 数据标准化
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train) # 训练集拟合+转换
X_test_scaled = scaler.transform(X_test) # 测试集仅转换
# 合并标准化数据
X_scaled = np.concatenate([X_train_scaled, X_test_scaled])
y = np.concatenate([y_train, y_test])
关键技术点:
-
stratify=y
保证分割后的数据保持原始类别分布 -
标准化处理防止特征尺度差异影响模型性能
-
合并数据集用于交叉验证
2.2.2 模型配置
models = {
"Random Forest": RandomForestClassifier(
n_estimators=200, # 增加树数量提升模型容量
max_depth=8, # 限制深度防止过拟合
n_jobs=-1 # 使用全部CPU核心
),
"Linear SVM": SVC(
kernel='rbf', # 选择径向基函数核
C=0.5, # 正则化强度参数
gamma='auto', # 自动计算gamma参数
probability=True # 启用概率估计
),
"KNN": KNeighborsClassifier(
n_neighbors=3, # 使用3近邻
n_jobs=-1 # 并行计算
)
}
scoring = {
'accuracy': make_scorer(accuracy_score),
'recall': make_scorer(recall_score, average='macro'), # 多分类宏平均
'f1': make_scorer(f1_score, average='macro')
}
参数调优说明:
-
随机森林:通过限制
max_depth
平衡偏差-方差 -
SVM:调整
C
值控制正则化强度 -
KNN:小邻域数适合高维度数据
2.2.3 交叉验证流程
best_score = -1
best_model_name = ""
best_model = None
for name, model in models.items():
# 交叉验证
cv_results = cross_validate(
model,
X_scaled,
y,
cv=3, # 3折交叉验证
scoring=scoring # 使用自定义指标
)
# 指标计算
acc = np.mean(cv_results['test_accuracy'])
rec = np.mean(cv_results['test_recall'])
f1 = np.mean(cv_results['test_f1'])
# 模型比较
if f1 > best_score:
best_score = f1
best_model_name = name
best_model = model
# 生成学习曲线
self.plot_learning_curve(model, X_scaled, y, name, output_dir)
评估策略:
-
使用3折交叉验证降低数据划分敏感性
-
以F1宏平均作为模型选择标准
-
同步输出各模型指标的标准差
2.3 学习曲线绘制 plot_learning_curve
2.3.1 数据计算
def plot_learning_curve(self, model, X, y, model_name, output_dir):
train_sizes, train_scores, test_scores = learning_curve(
model,
X,
y,
cv=3, # 3折交叉验证
scoring='accuracy', # 使用准确率指标
n_jobs=-1 # 并行计算
)
# 统计量计算
train_mean = np.mean(train_scores, axis=1)
train_std = np.std(train_scores, axis=1)
test_mean = np.mean(test_scores, axis=1)
test_std = np.std(test_scores, axis=1)
2.3.2 可视化实现
plt.figure(figsize=(8, 6))
plt.fill_between(
train_sizes,
train_mean - train_std,
train_mean + train_std,
alpha=0.1,
color="r"
)
plt.plot(train_sizes, train_mean, 'o-', color="r", label="Training score")
# 测试集曲线同理...
plt.title(f"Learning Curve ({model_name})")
plt.xlabel("Training Examples")
plt.ylabel("Accuracy Score")
plt.legend(loc="best")
# 保存图像
output_path = os.path.join(output_dir, f"{model_name}_learning_curve.png")
plt.savefig(output_path)
plt.close()
可视化分析:
-
阴影区域表示±1标准差范围
-
训练曲线(红色)与验证曲线(绿色)对比
-
图像尺寸设为8x6英寸保证可读性
三、使用流程示例
# 示例数据
X, y = load_your_data() # 需自定义数据加载方法
# 初始化训练器
trainer = ModelTrainer()
# 执行训练
best_model = trainer.train_model(
X,
y,
output_dir="my_models" # 指定输出目录
)
# 使用最佳模型预测
predictions = best_model.predict(new_data)
四、输出文件结构
model_plots/
├── Random Forest_learning_curve.png
├── Linear SVM_learning_curve.png
└── KNN_learning_curve.png
图像展示模型的学习过程,帮助诊断欠/过拟合问题