Pytorch卷积神经网络:基于MNIST数据集的手写数字识别

本文介绍了卷积神经网络(CNN)搭建过程中遇到的问题,包括transform应用时机、矩阵变形、梯度清零以及测试集预测值处理。通过修正这些问题,模型在MNIST数据集上达到了98.52%的准确率。为了进一步提升模型性能,提出了不使用Mini-batch预测、划分验证集、增加迭代次数和调整网络结构等优化建议。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

总结卷积神经网络搭建中遇到的一些问题:

一、transform应用在数据下载阶段(datasets)而不是数据载入阶段(DataLoader)
二、进入全连接层之前没有进行正确的矩阵变形
三、训练前没有进行梯度清零
四、测试集预测值没有使用torch.max()变换为预测标签,导致尺寸不匹配

代码展示:

import torch
import matplotlib.pyplot as plt
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader

batch_size=50
iteration=5
trans=transforms.Compose([transforms.ToTensor(),transforms.Normalize(0.15,0.30)])
train_set=datasets.MNIST("D:\桌面\CNN",train=True,download=True,transform=trans)
train_loader=DataLoader(dataset=train_set,batch_size=batch_size,shuffle=True)
test_set=datasets.MNIST("D\桌面\CNN",train=False,transform=trans,download=True)
test_loader=DataLoader(test_set,batch_size=batch_size)

class CNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1=torch.nn.Conv2d(1,10,kernel_size=5)
        self.conv2=torch.nn.Conv2d(10,20,kernel_size=5)
        self.pool1=torch.nn.MaxPool2d(2)
        self.linear1=torch.nn.Linear(320,10)

    def forward(self,x):
        x=self.pool1(self.conv1(x))
        x=self.pool1(self.conv2(x))
        x=x.view(-1,320)
        x=self.linear1(x)
        return x

model=CNN()
criterion=torch.nn.CrossEntropyLoss()
optimizer=torch.optim.Adam(model.parameters(),lr=0.01)
schedular=torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer,gamma=0.999)

def train():
    for epoch in range(iteration):
        for i,data in enumerate(train_loader,0):
            l=0.0
            train_data,train_label=data
            optimizer.zero_grad()
            pred_set=model(train_data)
            loss=criterion(pred_set,train_label)
            l+=loss
            loss.backward()
            optimizer.step()
            schedular.step()
            if(i%100==0):
                print("epoch:",epoch,"batch_sequence:",i/100,"loss:",l.data.item())

def test():
    with torch.no_grad():
        total=0.0
        correct=0.0
        for batch_index,data in enumerate(test_loader,0):
            test_data,test_label=data
            pred_data=model(test_data)
            #print("pred_shape:",pred_label.shape,"test_label:",test_label.shape)
            _,pred_label=torch.max(pred_data,dim=1)
            total+=pred_label.shape[0]
            correct+=(pred_label==test_label).sum().item()
        print("准确率为:",100.0*correct/total,"%")

train()
test()


预测结果:

在这里插入图片描述
#准确率在98.52%,可以增加迭代次数,进一步提高准确率。

模型改进空间:

一、测试集的预测可以考虑不使用Mini-batch
二、从测试集划分出开发集/验证集,对模型进行优化
三、增加迭代次数,进行更好的拟合
四、增加隐藏层的数量,使用更多的卷积层/池化层/全连接层

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值