总结卷积神经网络搭建中遇到的一些问题:
一、transform应用在数据下载阶段(datasets)而不是数据载入阶段(DataLoader)
二、进入全连接层之前没有进行正确的矩阵变形
三、训练前没有进行梯度清零
四、测试集预测值没有使用torch.max()变换为预测标签,导致尺寸不匹配
代码展示:
import torch
import matplotlib.pyplot as plt
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
# Hyperparameters.
batch_size = 50
iteration = 5  # number of training epochs

# NOTE: the transform must be attached to the dataset (datasets.MNIST),
# not to the DataLoader -- the loader only batches already-transformed samples.
trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(0.15, 0.30),  # approximate MNIST mean/std chosen by the author
])

# Raw strings so backslashes in the Windows path are not treated as escapes.
# Bug fix: the test-set root was missing the drive colon ("D\..." -> "D:\...").
train_set = datasets.MNIST(r"D:\桌面\CNN", train=True, download=True, transform=trans)
train_loader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True)
test_set = datasets.MNIST(r"D:\桌面\CNN", train=False, download=True, transform=trans)
test_loader = DataLoader(test_set, batch_size=batch_size)
class CNN(torch.nn.Module):
    """Small LeNet-style network for 28x28 single-channel MNIST images.

    Pipeline: conv1 -> maxpool -> conv2 -> maxpool -> flatten(320) -> linear(10).
    Returns raw class logits (no softmax), as expected by CrossEntropyLoss.

    NOTE(review): there is no nonlinearity between layers; inserting ReLU
    after each pooling step would likely improve accuracy — confirm before
    changing, as it alters trained behavior.
    """

    def __init__(self):
        super().__init__()
        # 1x28x28 -> 10x24x24 after the 5x5 conv, pooled to 10x12x12.
        self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5)
        # 10x12x12 -> 20x8x8 after the 5x5 conv, pooled to 20x4x4 = 320 features.
        self.conv2 = torch.nn.Conv2d(10, 20, kernel_size=5)
        self.pool1 = torch.nn.MaxPool2d(2)
        self.linear1 = torch.nn.Linear(320, 10)

    def forward(self, x):
        """Return logits of shape (batch, 10) for input of shape (batch, 1, 28, 28)."""
        for conv in (self.conv1, self.conv2):
            x = self.pool1(conv(x))
        # Flatten to (batch, 320) before the fully connected layer.
        return self.linear1(x.view(-1, 320))
# Model, loss, and optimisation objects (module-level globals used by train()/test()).
model=CNN()
# CrossEntropyLoss expects raw logits, so the network's forward deliberately
# returns un-normalised scores without a softmax layer.
criterion=torch.nn.CrossEntropyLoss()
optimizer=torch.optim.Adam(model.parameters(),lr=0.01)
# NOTE(review): "schedular" is a typo for "scheduler"; kept as-is because
# train() references this exact name. Multiplies the LR by 0.999 on each step().
schedular=torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer,gamma=0.999)
def train():
    """Train `model` for `iteration` epochs over `train_loader`.

    Zeroes gradients before every backward pass (a previously missed step,
    per the notes at the top of this file) and logs the current batch loss
    every 100 batches.
    """
    for epoch in range(iteration):
        for i, data in enumerate(train_loader, 0):
            train_data, train_label = data
            optimizer.zero_grad()  # clear stale gradients from the previous batch
            pred_set = model(train_data)
            loss = criterion(pred_set, train_label)
            loss.backward()
            optimizer.step()
            schedular.step()  # exponential LR decay applied once per batch
            if i % 100 == 0:
                # Bug fix: the old `l` accumulator was reset to 0.0 every batch
                # (so it never actually accumulated) and `l += loss` retained the
                # autograd graph. Printing loss.item() gives the same value
                # without holding onto tensors.
                print("epoch:", epoch, "batch_sequence:", i / 100, "loss:", loss.item())
def test():
    """Evaluate classification accuracy on `test_loader` and print it."""
    model.eval()  # good practice, although this net has no dropout/batch-norm
    with torch.no_grad():  # gradients are not needed for evaluation
        total = 0.0
        correct = 0.0
        for batch_index, data in enumerate(test_loader, 0):
            test_data, test_label = data
            pred_data = model(test_data)
            # torch.max over dim=1 turns (batch, 10) logits into predicted class
            # labels so their shape matches test_label (the size-mismatch issue
            # mentioned in the notes at the top of this file).
            _, pred_label = torch.max(pred_data, dim=1)
            total += pred_label.shape[0]
            correct += (pred_label == test_label).sum().item()
    print("准确率为:", 100.0 * correct / total, "%")
# Guard the entry point so importing this module does not trigger a full
# training run as a side effect.
if __name__ == "__main__":
    train()
    test()
预测结果:
准确率在98.52%左右,可以增加迭代次数,进一步提高准确率。
模型改进空间:
一、测试集的预测可以考虑不使用Mini-batch
二、从测试集划分出开发集/验证集,对模型进行优化
三、增加迭代次数,进行更好的拟合
四、增加隐藏层的数量,使用更多的卷积层/池化层/全连接层