💕个人主页:Puszy-CSDN博客💕
🎉 欢迎关注:👍点赞 📣留言 😍收藏 💖 💖 💖
在心里种花,人生才不会荒芜。 –– 《人民日报》
💖引言
本文基于PyTorch框架,采用使用PyTorch的nn.Module模块定义多层感知机(MLP)模型实现MNIST手写数字识别,在GPU上运行,实现高达98%的测试准确率,并完整展示从数据加载到模型部署的全流程。
PyTorch的模型大致结构普遍相似,也可修改相关参数的更换为其他简单图像分类任务,实测在CIFAR10数据集任务上分类准确率在60%左右,较为复杂的图像需要考虑使用CNN框架,或者YOLO模型会处理较好,实测在使用Resnet50分类冻结部分层情况下CIFAR10任务精度可达到95%以上。
目录
📚一、MNIST数据集
🚀1.1 数据概述
MNIST数据集是深度学习领域中广泛使用的基准测试数据集,常被称为“Hello World”项目,是许多研究者和学习者进入深度学习领域的起点。MNIST数据集最初由Yann LeCun等研究人员开发,目的是为了方便测试不同的机器学习算法的性能,这个数据集主要用于手写数字识别任务,包含了大量的手写数字图像及其对应的标签。
🚀1.2 数据集构成
MNIST数据集包含60,000张训练图像和10,000张测试图像,每张图像是28x28像素的灰度图像,代表数字0至9。
📚二、训练准备
首先,代码导入了必要的Python库,如torch
、torchvision
、matplotlib
等,用于深度学习模型的构建、数据处理和可视化。接着,检查CUDA是否可用,如果可用则使用GPU进行计算,代码如下:
import torch
import time
import torchvision
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import matplotlib.pyplot as plt
print(torch.cuda.is_available())
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
🚀2.1 数据集准备
定义了数据预处理操作,将图像转换为张量,并设置了批量大小。通过torchvision.datasets.MNIST
下载MNIST训练集和测试集,并使用DataLoader
创建数据加载器,方便后续批量处理数据:
#加载数据
transform=transforms.Compose([
transforms.ToTensor(),
])
batch_size=64
#下载数据集
Train_datast = datasets.MNIST(root='./',train=True,download=True,transform=transform)
test_datast = datasets.MNIST(root='./',train=False,download=True,transform=transform)
train_loader=DataLoader(dataset=Train_datast,batch_size=batch_size,shuffle=True,drop_last=True)
test_loader=DataLoader(dataset=test_datast,batch_size=batch_size,shuffle=True,drop_last=True)
# 打印数据集信息
print(f"训练集大小: {len(Train_datast)}")
print(f"测试集大小: {len(test_datast)}")
🚀2.2 批次样本可视化
代码从训练数据加载器中取出一个批次的图像和标签,使用torchvision.utils.make_grid
函数将这些图像拼接成一个网格,并使用matplotlib
进行可视化展示,以直观地查看数据集的样本情况:
# 构建图像网格
batch_images, batch_labels = next(iter(train_loader))
grid_img = torchvision.utils.make_grid(
tensor=batch_images, # 输入张量 [B, C, H, W]
nrow=8, # 每行显示图片数(根据批次大小调整)
padding=2, # 图片间距
)
print(grid_img.shape)
print(grid_img.permute(1, 2, 0).shape)
# 可视化
plt.figure(figsize=(12, 6))
plt.imshow(grid_img.permute(1, 2, 0).squeeze(), cmap='gray') # 维度调整
plt.axis('off')
plt.title(f"Batch Size: {len(batch_images)} | Label Range: {batch_labels.min()}-{batch_labels.max()}")
plt.show()
运行结果如下
📚三、模型架构设计
🚀3.1 MLP网络结构
定义了一个多层感知机(MLP)类MLP
,继承自nn.Module
。该模型包含多个全连接层,中间使用了批量归一化(nn.BatchNorm1d
)、ReLU激活函数(nn.ReLU
)和Dropout正则化(nn.Dropout
),以提高模型的泛化能力。最后一层输出维度为10,对应MNIST数据集中的10个数字类别:
#定义网络架构
class MLP(nn.Module):
def __init__(self,hidden_size,height,width):
super(MLP,self).__init__()
self.layers = nn.Sequential(
nn.Flatten(),
nn.Linear(height*width, hidden_size),
nn.BatchNorm1d(hidden_size),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(hidden_size, hidden_size//2),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(hidden_size//2, hidden_size//4),
nn.ReLU(),
nn.Dropout(0.2),
nn.BatchNorm1d(hidden_size//4),
nn.Linear(hidden_size//4, 10),
)
def forward(self, x):
return self.layers(x)
# 训练参数
batch_num,height,width=batch_images[0].shape
model=MLP(hidden_size=4096,height=height,width=width).to(device)
print(batch_images[0].shape)
print(height*width)
torch.flatten(batch_images[0]).shape
📚四、模型训练模块
🚀4.1 训练配置
参数 | 值 | 作用说明 |
---|---|---|
优化器 | AdamW | 结合权重衰减的正则化优化 |
学习率 | 0.001 | 平衡收敛速度与稳定性 |
权重衰减 | 1e-4 | 控制模型复杂度 |
Early Stopping | 10轮 | 防止过拟合 |
🚀4.2 训练流程
设置训练参数,包括训练轮数(epochs
)、初始最优准确率(best_accuracy
)等。在每个训练轮次中,模型处于训练模式(model.train()
),遍历训练数据加载器,进行前向传播、损失计算(使用交叉熵损失函数criterion
)、反向传播和参数更新(使用优化器optimizer
,代码中未定义,需补充)。在验证阶段,模型处于评估模式(model.eval()
),遍历测试数据加载器,计算测试集的损失和准确率。同时,记录每个轮次的训练损失、测试损失、训练准确率和测试准确率,并更新最优模型:
# 训练参数
epochs = 100
best_accuracy = 0.0
correct_train=0.0
loss_train=0
loss_train_list=[]
loss_test_list=[]
train_acc_list = []
test_acc_list = []
for epoch in range(epochs):
since = time.time()
model.train()
print('-' * 20)
print(f'Epoch {epoch+1}/{epochs}')
total_loss_train=0
correct_train=0
k_test=0
k_train=0
for images, labels in train_loader:
images = images.to(device)
labels = labels.to(device)
# 前向传播
outputs = model(images)
_, predicted_train = torch.max(outputs, dim=1)
loss_train = criterion(outputs, labels)
correct_train += (predicted_train == labels).sum().item()
# 反向传播和优化
optimizer.zero_grad()
loss_train.backward()
optimizer.step()
total_loss_train+=loss_train.item()
k_train+=1
# 验证阶段
model.eval()
correct_test = 0
loss_test = []
total_loss_test=0
with torch.no_grad():
for images, labels in test_loader:
images = images.to(device)
labels = labels.to(device)
outputs = model(images)
values_test, predicted_test = torch.max(outputs, 1)
loss_test =criterion(outputs, labels)
total_loss_test+=loss_test.item()
correct_test += (predicted_test == labels).sum().item()
k_test+=1
#更新保存损失,准确率
current_train_acc=correct_train/(k_train*batch_size)
current_test_acc = correct_test/(k_test*batch_size)
train_acc_list.append(current_train_acc)
test_acc_list.append(current_test_acc)
loss_train_list.append(total_loss_train)
loss_test_list.append(total_loss_test)
# 实时更新最优模型
if current_test_acc >= best_accuracy:
best_accuracy = current_test_acc
epochs_since_improvement = 0
torch.save(model, 'best_model.pkl')
else:
epochs_since_improvement += 1
time_elapsed = time.time() - since
print(f'训练集交叉熵损失为:{total_loss_train} 测试集交叉熵损失为:{total_loss_test}')
print(f'训练集准确率为:{current_train_acc} 测试集准确率为:{current_test_acc}')
print('本轮用时:{:.0f}m {:.0f}s epochs_since_improvement:{}'.format(time_elapsed // 60, time_elapsed % 60,epochs_since_improvement))
patience=10
if epochs_since_improvement >= patience:
print('-'*20)
print('测试集损失在',patience, '轮内没有提升,训练结束')
print('停止在', epoch-patience)
print(f"最优准确率为:{best_accuracy}")
time_elapsed = time.time() - since
print('Training complete in {:.0f}m {:.0f}s '.format(time_elapsed // 60, time_elapsed % 60))
break
如果在连续patience
轮中测试集准确率没有提升,则提前停止训练。
运行结果
📚五、可视化损失,准确度
使用matplotlib
创建复合图表,分别绘制训练过程中的损失曲线和准确率曲线,直观展示模型的训练效果和泛化能力:
import matplotlib.pyplot as plt
import numpy as np
# 配置全局绘图参数
plt.rcParams['font.sans-serif'] = ['SimHei'] # 支持中文显示
plt.rcParams['axes.unicode_minus'] = False # 解决负号显示异常
# 使用专业主题
# 创建复合图表
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 6), dpi=100)
# 子图1:损失曲线
ax1.plot(np.array(loss_test_list)/batch_size, 'orange', lw=1.5, alpha=0.8, label='测试损失')
ax1.set_title(' 损失曲线演进分析', fontsize=14, pad=12)
ax1.set_xlabel(' 训练轮次', fontsize=12)
ax1.set_ylabel(' 标准化损失值', fontsize=12)
ax1.legend(frameon=True, shadow=True, loc='upper right')
# 子图2:准确率曲线
ax2.plot(train_acc_list, 'b-', lw=1.5, alpha=0.8, label='训练准确率')
ax2.plot(test_acc_list, 'orange', lw=1.5, alpha=0.8, label='测试准确率')
ax2.set_title(' 准确率发展轨迹', fontsize=14, pad=12)
ax2.set_xlabel(' 训练轮次', fontsize=12)
ax2.set_ylabel(' 分类准确率', fontsize=12)
ax2.set_ylim(min(train_acc_list), 1) # 固定准确率范围
ax2.legend(frameon=True, shadow=True, loc='lower right')
# 添加全局注释
plt.suptitle(f'MNIST 分类器训练监控报告 (最优准确率: {best_accuracy:.2%})',
y=1.02, fontsize=16, color='#2c3e50')
# 保存与显示
plt.tight_layout()
plt.savefig('MNIST_training_analysis_MLP.png', bbox_inches='tight', dpi=300)
plt.show()
结果可视化,准确率在98.6%左右
📚六、可视化预测结果
加载保存的最优模型,从测试数据加载器中取出一个批次的图像和标签,使用模型进行预测。将预测结果和真实标签转换为NumPy数组,并使用matplotlib
创建图像网格,将预测结果和真实标签标注在每个图像上,直观展示模型的预测效果:
#加载模型产生预测结果
best_model=torch.load('./best_model.pkl',weights_only=False,map_location='cuda:0')
best_model.eval()
batch_images, batch_labels = next(iter(test_loader))
batch_images=batch_images.to(device)
values_test, predicted_test = torch.max(best_model(batch_images), 1)
#转移到cpu numpy格式
batch_labels=batch_labels.cpu().numpy()
predicted_test=predicted_test.cpu().numpy()
# print(f'实际标签:{batch_labels}')
# print(f'预测标签:{predicted_test}')
# 创建可视化图像网格
nrow = 16# 与make_grid中的nrow一致
ncol = batch_size // nrow # 计算列数
fig, axes = plt.subplots(nrow, ncol, figsize=(15, 15))
axes = axes.flatten()
for i in range(batch_size):
# 处理图像维度 [C, H, W] -> [H, W]
img = batch_images[i].cpu().permute(1, 2, 0).squeeze().numpy()
#显示图像
axes[i].imshow(img, cmap='gray')
#隐藏坐标轴
axes[i].axis('off')
# 设置标题颜色和内容
pred = predicted_test[i]
true = batch_labels[i]
color = 'green' if pred == true else 'red'
axes[i].set_title(f'P:{pred} | T:{true}', color=color, fontsize=8)
# 隐藏多余的子图(如果有)
for j in range(batch_size, len(axes)):
axes[j].axis('off')
plt.tight_layout()
plt.suptitle(f"Batch Size: {batch_size} | Label Range: {batch_labels.min()}-{batch_labels.max()}", y=1.02)
plt.show()
运行结果
-
绿色标注表示预测正确
-
红色标注表示预测错误
-
典型错误案例:书写模糊/非常规写法