class BPNetwork(torch.nn.Module):
def __init__(self):
super(BPNetwork, self).__init__()
"""
定义第一个线性层,
输入为图片(28x28),
输出为第一个隐层的输入,大小为128。
"""
self.linear1 = torch.nn.Linear(28 * 28, 128)
# 在第一个隐层使用ReLU激活函数
self.relu1 = torch.nn.ReLU()
"""
定义第二个线性层,
输入是第一个隐层的输出,
输出为第二个隐层的输入,大小为64。
"""
self.linear2 = torch.nn.Linear(128, 64)
# 在第二个隐层使用ReLU激活函数
self.relu2 = torch.nn.ReLU()
"""
定义第三个线性层,
输入是第二个隐层的输出,
输出为输出层,大小为10
"""
self.linear3 = torch.nn.Linear(64, 10)
# 最终的输出经过softmax进行归一化
self.softmax = torch.nn.LogSoftmax(dim=1)
def forward(self, x):
"""
定义神经网络的前向传播
x: 图片数据, shape为(64, 1, 28, 28)
"""
# 首先将x的shape转为(64, 784)
x = x.view(x.shape[0], -1)
# 接下来进行前向传播
x = self.linear1(x)
x = self.relu1(x)
x = self.linear2(x)
x = self.relu2(x)
x = self.linear3(x)
x = self.softmax(x)
# 上述一串,可以直接使用 x = self.model(x) 代替。
return x
定义了一个简单的神经网络模型,该模型继承自PyTorch的torch.nn.Module
。这个模型用于处理图像数据,并最后输出一个经过softmax归一化的10维向量。下面是对代码的详细解释:
-
初始化函数
__init__
:self.linear1 = torch.nn.Linear(28 * 28, 128)
: 定义第一个线性层(全连接层)。输入是28x28的图像数据(即784个像素值),输出是一个128维的向量。self.relu1 = torch.nn.ReLU()
: 在第一个隐层后使用ReLU激活函数。self.linear2 = torch.nn.Linear(128, 64)
: 定义第二个线性层。输入是第一个隐层的128维输出,输出是一个64维的向量。self.relu2 = torch.nn.ReLU()
: 在第二个隐层后使用ReLU激活函数。self.linear3 = torch.nn.Linear(64, 10)
: 定义第三个线性层。输入是第二个隐层的64维输出,输出是一个10维的向量。self.softmax = torch.nn.LogSoftmax(dim=1)
: 在输出层使用LogSoftmax进行归一化。dim=1
表示在第二个维度(即每一行)上进行归一化。
-
前向传播函数
forward
:- 输入
x
是一个四维张量,其shape为(64, 1, 28, 28)
。这表示有64张图像,每张图像为单通道(灰度图),大小为28x28。 x = x.view(x.shape[0], -1)
: 这行代码将输入的四维张量转换为一个二维张量,其shape为(64, 784)
。这样做的目的是将每张图像展平为一个784维的向量,以便可以输入到第一个线性层。- 接下来的几行
x = self.layer(x)
是对网络进行前向传播。数据依次通过线性层、ReLU激活函数,直到最后的LogSoftmax归一化。 - 最后,函数返回经过LogSoftmax归一化的输出
x
。
- 输入
- 在
forward
函数注释中提到“上述一串,可以直接使用 x = self.model(x) 代替。”,但实际上在代码中并没有定义self.model
。这可能是一个错误或者是一个未完成的代码片段。如果你想使用这种方式来简化代码,你需要先定义一个包含所有层的Sequential
模型,并将其赋值给self.model
。
整体来说,这是一个非常基础的神经网络模型,适用于处理简单的图像分类任务。
# 第2步:加载数据集
# MNIST包含70,000张手写数字图像: 60,000张用于训练,10,000张用于测试。
# 图像是灰度的,28×28像素的,并且居中的,以减少预处理和加快运行。
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
# 使用torchvision读取数据
trainset = torchvision.datasets