import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import confusion_matrix
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# Configure matplotlib so the Chinese titles/labels used by the plots can render.
# 'Microsoft YaHei' stays first, so behaviour is unchanged wherever it is
# installed; the remaining entries are fallbacks tried in order on systems that
# lack it (NOTE(review): adjust the list to the CJK fonts actually installed).
plt.rcParams['font.sans-serif'] = [
    'Microsoft YaHei',
    'SimHei',
    'PingFang SC',
    'Noto Sans CJK SC',
    'WenQuanYi Micro Hei',
    'DejaVu Sans',
]
# Draw minus signs with the ASCII hyphen so they still display with fonts that
# have no glyph for the Unicode minus sign.
plt.rcParams['axes.unicode_minus'] = False
# LeNet model definition
class LeNet(nn.Module):
    """LeNet-style CNN: two conv + max-pool stages followed by three linear layers.

    ``fc1`` expects ``16 * 5 * 5`` features per sample, which is what the two
    5x5 convolutions and 2x2 poolings produce for a 32x32 input
    (32 -> 28 -> 14 -> 10 -> 5).

    Args:
        in_channels: Number of channels of the input images. Defaults to 1,
            the value that was previously hard-coded.
        num_classes: Size of the output layer. Defaults to 10, the value that
            was previously hard-coded.
    """

    def __init__(self, in_channels: int = 1, num_classes: int = 10):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)  # stateless, shared by both conv stages
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw (un-normalised) class scores for a batch of images.

        Args:
            x: Batch of images shaped ``(batch, in_channels, 32, 32)``.

        Returns:
            Tensor of shape ``(batch, num_classes)``.

        Raises:
            RuntimeError: If the spatial size of ``x`` does not yield
                ``16 * 5 * 5`` features per sample after the conv stages.
        """
        x = self.pool(nn.functional.relu(self.conv1(x)))
        x = self.pool(nn.functional.relu(self.conv2(x)))
        # Flatten everything except the batch dimension. Unlike
        # ``view(-1, 400)``, this never folds samples into each other when the
        # input size is wrong; the mismatch surfaces as a shape error in fc1.
        x = torch.flatten(x, 1)
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        return self.fc3(x)
# Preprocessing: resize every image to 32x32, then convert it to a tensor
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
])
# MNIST train/test splits share everything except the `train` flag
_mnist_kwargs = dict(root='./data', download=True, transform=transform)
train_set = datasets.MNIST(train=True, **_mnist_kwargs)
test_set = datasets.MNIST(train=False, **_mnist_kwargs)
# Only the training batches are shuffled; the test order stays fixed
train_loader = DataLoader(dataset=train_set, batch_size=256, shuffle=True)
test_loader = DataLoader(dataset=test_set, batch_size=256, shuffle=False)
# Build the network together with its loss function and optimizer
model = LeNet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
# History lists: one entry is appended per epoch (mean training loss and
# test accuracy in percent)
train_losses, test_accuracy = [], []
# Train the model, evaluating on the test set after every epoch
num_epochs = 5  # single source for both the loop bound and the progress message
for epoch in range(num_epochs):
    # ---- training pass ----
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # Mean of the per-batch losses seen in this epoch
    avg_train_loss = running_loss / len(train_loader)
    train_losses.append(avg_train_loss)

    # ---- evaluation pass: accuracy on the test set ----
    model.eval()
    correct = 0
    total = 0
    # Re-created every epoch, so after the loop these hold only the labels and
    # predictions of the final epoch's evaluation.
    all_labels = []
    all_predicted = []
    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images)
            # Gradient tracking is already off inside no_grad(), so the
            # discouraged `.data` accessor is not needed here.
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            all_labels.extend(labels.cpu().numpy())
            all_predicted.extend(predicted.cpu().numpy())
    epoch_accuracy = 100 * correct / total
    test_accuracy.append(epoch_accuracy)

    # Report this epoch's results
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_train_loss:.4f}, 准确率: {epoch_accuracy:.2f}%')
# Plot the per-epoch training loss and test accuracy
def _plot_epoch_curve(values, title, ylabel):
    """Show one line plot of ``values`` against 1-based epoch numbers.

    Args:
        values: Sequence with one number per epoch.
        title: Text used both as the figure title and as the legend label.
        ylabel: Label for the y-axis.
    """
    # Number epochs from 1 so the x-axis agrees with the "Epoch [k/N]" lines
    # printed during training (plotting the bare list would start at 0).
    epochs = list(range(1, len(values) + 1))
    plt.figure(figsize=(10, 5))
    plt.plot(epochs, values, label=title)
    plt.title(title)
    plt.xlabel('周期数')
    plt.ylabel(ylabel)
    # One tick per epoch instead of automatically chosen fractional ticks
    plt.xticks(epochs)
    plt.legend()
    plt.show()


# Training-loss curve
_plot_epoch_curve(train_losses, '训练损失', '损失')
# Test-accuracy curve
_plot_epoch_curve(test_accuracy, '测试准确率', '准确率 (%)')
# Draw the confusion matrix of true vs. predicted labels as an annotated heatmap
conf_mat = confusion_matrix(y_true=all_labels, y_pred=all_predicted)
plt.figure(figsize=(10, 10))
# Integer cell annotations on a blue colour scale
heatmap_style = {'annot': True, 'fmt': 'd', 'cmap': 'Blues'}
sns.heatmap(conf_mat, **heatmap_style)
plt.title('混淆矩阵')
plt.ylabel('实际')
plt.xlabel('预测')
plt.show()