机器学习模型训练模块技术文档

一、模块结构概览

import numpy as np
from sklearn.model_selection import cross_validate, learning_curve
from sklearn.pipeline import make_pipeline
from sklearn.svm import SVC
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import make_scorer, accuracy_score, recall_score, f1_score
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.utils import shuffle
import os

依赖说明

  • numpy:处理数值计算

  • sklearn:提供机器学习算法和工具

  • matplotlib:可视化学习曲线

  • os:处理文件路径操作

二、核心类定义

2.1 类初始化

class ModelTrainer:
    def __init__(self):
        pass

功能:创建模型训练器的基础类,当前无需特殊初始化参数 

 

2.2 主训练方法 train_model

2.2.1 数据准备阶段
def train_model(self, X, y, output_dir="model_plots"):
    # 创建输出文件夹
    os.makedirs(output_dir, exist_ok=True)
    
    # 数据分割
    X_train, X_test, y_train, y_test = train_test_split(
        X, y,
        test_size=0.2,        # 20%测试集
        stratify=y,           # 保持类别分布
        random_state=42       # 可重复性种子
    )
    
    # 数据标准化
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)  # 训练集拟合+转换
    X_test_scaled = scaler.transform(X_test)        # 测试集仅转换
    
    # 合并标准化数据
    X_scaled = np.concatenate([X_train_scaled, X_test_scaled])
    y = np.concatenate([y_train, y_test])

关键技术点

  • stratify=y 保证分割后的数据保持原始类别分布

  • 标准化处理防止特征尺度差异影响模型性能

  • 合并数据集用于交叉验证

2.2.2 模型配置
models = {
    "Random Forest": RandomForestClassifier(
        n_estimators=200,  # 增加树数量提升模型容量
        max_depth=8,        # 限制深度防止过拟合
        n_jobs=-1          # 使用全部CPU核心
    ),
    "Linear SVM": SVC(
        kernel='rbf',       # 选择径向基函数核
        C=0.5,             # 正则化强度参数
        gamma='auto',      # 自动计算gamma参数
        probability=True   # 启用概率估计
    ),
    "KNN": KNeighborsClassifier(
        n_neighbors=3,     # 使用3近邻
        n_jobs=-1          # 并行计算
    )
}

scoring = {
    'accuracy': make_scorer(accuracy_score),
    'recall': make_scorer(recall_score, average='macro'),  # 多分类宏平均
    'f1': make_scorer(f1_score, average='macro')
}

参数调优说明

  • 随机森林:通过限制max_depth平衡偏差-方差

  • SVM:调整C值控制正则化强度

  • KNN:小邻域数适合高维度数据

2.2.3 交叉验证流程
best_score = -1
best_model_name = ""
best_model = None

for name, model in models.items():
    # 交叉验证
    cv_results = cross_validate(
        model, 
        X_scaled, 
        y, 
        cv=3,              # 3折交叉验证
        scoring=scoring    # 使用自定义指标
    )
    
    # 指标计算
    acc = np.mean(cv_results['test_accuracy'])
    rec = np.mean(cv_results['test_recall'])
    f1 = np.mean(cv_results['test_f1'])
    
    # 模型比较
    if f1 > best_score:
        best_score = f1
        best_model_name = name
        best_model = model
    
    # 生成学习曲线
    self.plot_learning_curve(model, X_scaled, y, name, output_dir)

评估策略

  • 使用3折交叉验证降低数据划分敏感性

  • 以F1宏平均作为模型选择标准

  • 同步输出各模型指标的标准差

2.3 学习曲线绘制 plot_learning_curve

2.3.1 数据计算

def plot_learning_curve(self, model, X, y, model_name, output_dir):
    train_sizes, train_scores, test_scores = learning_curve(
        model, 
        X, 
        y, 
        cv=3,               # 3折交叉验证
        scoring='accuracy', # 使用准确率指标
        n_jobs=-1          # 并行计算
    )
    
    # 统计量计算
    train_mean = np.mean(train_scores, axis=1)
    train_std = np.std(train_scores, axis=1)
    test_mean = np.mean(test_scores, axis=1)
    test_std = np.std(test_scores, axis=1)
2.3.2 可视化实现
    plt.figure(figsize=(8, 6))
    plt.fill_between(
        train_sizes,
        train_mean - train_std,
        train_mean + train_std,
        alpha=0.1, 
        color="r"
    )
    plt.plot(train_sizes, train_mean, 'o-', color="r", label="Training score")
    # 测试集曲线同理...
    
    plt.title(f"Learning Curve ({model_name})")
    plt.xlabel("Training Examples")
    plt.ylabel("Accuracy Score")
    plt.legend(loc="best")
    
    # 保存图像
    output_path = os.path.join(output_dir, f"{model_name}_learning_curve.png")
    plt.savefig(output_path)
    plt.close()

可视化分析

  • 阴影区域表示±1标准差范围

  • 训练曲线(红色)与验证曲线(绿色)对比

  • 图像尺寸设为8x6英寸保证可读性

三、使用流程示例

# 示例数据
X, y = load_your_data()  # 需自定义数据加载方法

# 初始化训练器
trainer = ModelTrainer()

# 执行训练
best_model = trainer.train_model(
    X, 
    y,
    output_dir="my_models"  # 指定输出目录
)

# 使用最佳模型预测
predictions = best_model.predict(new_data)

四、输出文件结构


model_plots/
├── Random Forest_learning_curve.png
├── Linear SVM_learning_curve.png
└── KNN_learning_curve.png

图像展示模型的学习过程,帮助诊断欠/过拟合问题

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

happydog007

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值