Complete code:
Below is a complete PyTorch neural network implementation with a clear structure that is easy to maintain and reuse. It features a modular design, centralized configuration management, and common utilities; it can be used directly for image classification and extended easily to other tasks.
A complete, modular PyTorch neural network implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
from typing import List, Tuple, Optional, Dict, Any
# ---------------------------
# 1. Configuration management
# ---------------------------
class NetworkConfig:
    """Central store for all network hyperparameters, with validation."""
    def __init__(self,
                 input_shape: Tuple[int, int, int] = (3, 32, 32),
                 num_classes: int = 10,
                 conv_layers: List[Dict[str, Any]] = None,
                 fc_layers: List[Dict[str, Any]] = None,
                 activation: str = 'relu',
                 use_batchnorm: bool = True,
                 initializer: str = 'kaiming'):
        """
        Initialize the network configuration.

        Args:
            input_shape: input shape as (channels, height, width)
            num_classes: number of output classes
            conv_layers: list of dicts, one per conv block, with its parameters
            fc_layers: list of dicts, one per fully connected block
            activation: activation type ('relu', 'leaky_relu', 'elu', 'sigmoid')
            use_batchnorm: whether to use batch normalization
            initializer: weight init scheme ('kaiming', 'xavier', 'normal')
        """
        # Default convolutional layer configuration
        if conv_layers is None:
            conv_layers = [
                {'out_channels': 32, 'kernel_size': 3, 'stride': 1, 'padding': 1, 'pool_size': 2},
                {'out_channels': 64, 'kernel_size': 3, 'stride': 1, 'padding': 1, 'pool_size': 2},
                {'out_channels': 128, 'kernel_size': 3, 'stride': 1, 'padding': 1, 'pool_size': 2}
            ]
        # Default fully connected layer configuration
        if fc_layers is None:
            fc_layers = [
                {'units': 512, 'dropout': 0.5},
            ]
        self.input_shape = input_shape
        self.num_classes = num_classes
        self.conv_layers = conv_layers
        self.fc_layers = fc_layers
        self.activation = activation
        self.use_batchnorm = use_batchnorm
        self.initializer = initializer
        # Validate the configuration
        self._validate_config()

    def _validate_config(self):
        """Check that all configuration parameters are valid."""
        # Validate input shape
        assert len(self.input_shape) == 3, "input_shape must be (channels, height, width)"
        assert all(isinstance(x, int) and x > 0 for x in self.input_shape), "input_shape entries must be positive integers"
        # Validate conv layer configs
        for i, layer in enumerate(self.conv_layers):
            required_keys = ['out_channels', 'kernel_size', 'stride', 'padding', 'pool_size']
            assert all(key in layer for key in required_keys), f"conv layer {i} is missing required keys"
            assert layer['out_channels'] > 0, f"conv layer {i}: out_channels must be positive"
            assert layer['kernel_size'] % 2 == 1, f"conv layer {i}: kernel_size must be odd"
        # Validate fc layer configs
        for i, layer in enumerate(self.fc_layers):
            assert 'units' in layer and layer['units'] > 0, f"fc layer {i}: units must be positive"
            assert 'dropout' in layer and 0 <= layer['dropout'] < 1, f"fc layer {i}: dropout must be in [0, 1)"
        # Validate activation
        valid_activations = ['relu', 'leaky_relu', 'elu', 'sigmoid']
        assert self.activation in valid_activations, f"activation must be one of {valid_activations}"
        # Validate initializer
        valid_initializers = ['kaiming', 'xavier', 'normal']
        assert self.initializer in valid_initializers, f"initializer must be one of {valid_initializers}"

    def __repr__(self):
        """Readable configuration summary."""
        return (f"NetworkConfig(input_shape={self.input_shape}, num_classes={self.num_classes}, "
                f"conv_layers={len(self.conv_layers)}, fc_layers={len(self.fc_layers)}, "
                f"activation={self.activation}, use_batchnorm={self.use_batchnorm})")
# ---------------------------
# 2. Basic building blocks
# ---------------------------
class ConvBlock(nn.Module):
    """Convolutional block: conv -> batch norm -> activation -> pooling."""
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int,
                 stride: int = 1,
                 padding: int = 1,
                 pool_size: int = 2,
                 activation: str = 'relu',
                 use_batchnorm: bool = True):
        super().__init__()
        # Convolutional layer
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding
        )
        # Optional batch normalization
        self.bn = nn.BatchNorm2d(out_channels) if use_batchnorm else None
        # Activation function
        self.activation = self._get_activation(activation)
        # Pooling layer
        self.pool = nn.MaxPool2d(kernel_size=pool_size, stride=pool_size)
        # Record the output channel count for downstream layers
        self.out_channels = out_channels

    def _get_activation(self, activation: str) -> nn.Module:
        """Return the activation module for the given name."""
        activations = {
            'relu': nn.ReLU(inplace=True),
            'leaky_relu': nn.LeakyReLU(0.1, inplace=True),
            'elu': nn.ELU(inplace=True),
            'sigmoid': nn.Sigmoid()
        }
        return activations[activation]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass."""
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        x = self.activation(x)
        x = self.pool(x)
        return x

    def __repr__(self):
        return (f"ConvBlock(in_channels={self.conv.in_channels}, out_channels={self.conv.out_channels}, "
                f"kernel_size={self.conv.kernel_size}, pool_size={self.pool.kernel_size})")


class FcBlock(nn.Module):
    """Fully connected block: linear -> activation -> dropout."""
    def __init__(self,
                 in_features: int,
                 out_features: int,
                 activation: str = 'relu',
                 dropout: float = 0.0):
        super().__init__()
        self.fc = nn.Linear(in_features, out_features)
        self.activation = self._get_activation(activation)
        self.dropout = nn.Dropout(dropout) if dropout > 0 else None
        self.out_features = out_features

    def _get_activation(self, activation: str) -> nn.Module:
        """Same activation lookup as ConvBlock."""
        activations = {
            'relu': nn.ReLU(inplace=True),
            'leaky_relu': nn.LeakyReLU(0.1, inplace=True),
            'elu': nn.ELU(inplace=True),
            'sigmoid': nn.Sigmoid()
        }
        return activations[activation]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass."""
        x = self.fc(x)
        x = self.activation(x)
        if self.dropout is not None:
            x = self.dropout(x)
        return x

    def __repr__(self):
        # Note: the conditional must live inside the f-string expression
        return (f"FcBlock(in_features={self.fc.in_features}, out_features={self.fc.out_features}, "
                f"dropout={self.dropout.p if self.dropout else 0})")
# ---------------------------
# 3. Main network
# ---------------------------
class ImageClassifier(nn.Module):
    """Generic image classifier whose structure is driven entirely by the config."""
    def __init__(self, config: NetworkConfig):
        super().__init__()
        self.config = config
        # Build the network
        self.features = self._build_features()      # feature extractor (conv blocks)
        self.classifier = self._build_classifier()  # classifier head (fc blocks)
        # Initialize weights
        self._initialize_weights()

    def _build_features(self) -> nn.Sequential:
        """Build the feature extraction stage (conv blocks)."""
        conv_blocks = []
        in_channels = self.config.input_shape[0]  # input channel count
        for layer_cfg in self.config.conv_layers:
            conv_block = ConvBlock(
                in_channels=in_channels,
                out_channels=layer_cfg['out_channels'],
                kernel_size=layer_cfg['kernel_size'],
                stride=layer_cfg['stride'],
                padding=layer_cfg['padding'],
                pool_size=layer_cfg['pool_size'],
                activation=self.config.activation,
                use_batchnorm=self.config.use_batchnorm
            )
            conv_blocks.append(conv_block)
            in_channels = layer_cfg['out_channels']  # next block's input channels
        return nn.Sequential(*conv_blocks)

    def _build_classifier(self) -> nn.Sequential:
        """Build the classification stage (fc blocks)."""
        fc_blocks = []
        # Infer the flattened feature size by running a dummy forward pass
        with torch.no_grad():
            dummy_input = torch.zeros(1, *self.config.input_shape)
            features = self.features(dummy_input)
            in_features = features.view(1, -1).size(1)
        # Fully connected blocks
        for layer_cfg in self.config.fc_layers:
            fc_block = FcBlock(
                in_features=in_features,
                out_features=layer_cfg['units'],
                activation=self.config.activation,
                dropout=layer_cfg['dropout']
            )
            fc_blocks.append(fc_block)
            in_features = layer_cfg['units']
        # Final output layer
        fc_blocks.append(nn.Linear(in_features, self.config.num_classes))
        return nn.Sequential(*fc_blocks)

    def _initialize_weights(self):
        """Initialize weights according to the configured scheme."""
        # kaiming_normal_ relies on calculate_gain, which does not know 'elu',
        # so fall back to 'relu' in that case to avoid a ValueError
        nonlinearity = 'relu' if self.config.activation == 'elu' else self.config.activation
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                if self.config.initializer == 'kaiming':
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity=nonlinearity)
                elif self.config.initializer == 'xavier':
                    nn.init.xavier_normal_(m.weight)
                else:  # normal
                    nn.init.normal_(m.weight, mean=0, std=0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass."""
        x = self.features(x)       # feature extraction
        x = x.view(x.size(0), -1)  # flatten the feature maps
        x = self.classifier(x)     # classification
        return x

    def get_layer_output(self, x: torch.Tensor, layer_name: str) -> torch.Tensor:
        """Return the output of a named layer (useful for debugging and visualization)."""
        # Feature extraction stage
        if layer_name.startswith('conv_'):
            index = int(layer_name.split('_')[1]) - 1
            for i, layer in enumerate(self.features):
                x = layer(x)
                if i == index:
                    return x
            raise ValueError(f"conv layer {layer_name} does not exist")
        # Classification stage
        elif layer_name.startswith('fc_'):
            index = int(layer_name.split('_')[1]) - 1
            x = self.features(x)
            x = x.view(x.size(0), -1)
            for i, layer in enumerate(self.classifier):
                x = layer(x)
                if i == index:
                    return x
            raise ValueError(f"fc layer {layer_name} does not exist")
        else:
            raise ValueError(f"unknown layer name: {layer_name}")
# ---------------------------
# 4. Network factory and utilities
# ---------------------------
class NetworkFactory:
    """Factory class for creating network instances with different configurations."""
    @staticmethod
    def create_cifar10_model() -> ImageClassifier:
        """Create a model for the CIFAR-10 dataset."""
        config = NetworkConfig(
            input_shape=(3, 32, 32),
            num_classes=10,
            activation='relu',
            use_batchnorm=True
        )
        return ImageClassifier(config)

    @staticmethod
    def create_mnist_model() -> ImageClassifier:
        """Create a model for the MNIST dataset."""
        conv_layers = [
            {'out_channels': 16, 'kernel_size': 3, 'stride': 1, 'padding': 1, 'pool_size': 2},
            {'out_channels': 32, 'kernel_size': 3, 'stride': 1, 'padding': 1, 'pool_size': 2}
        ]
        fc_layers = [
            {'units': 128, 'dropout': 0.3},
        ]
        config = NetworkConfig(
            input_shape=(1, 28, 28),
            num_classes=10,
            conv_layers=conv_layers,
            fc_layers=fc_layers,
            activation='relu',
            use_batchnorm=True
        )
        return ImageClassifier(config)

    @staticmethod
    def create_custom_model(config_params: Dict[str, Any]) -> ImageClassifier:
        """Create a model from custom configuration parameters."""
        config = NetworkConfig(**config_params)
        return ImageClassifier(config)
def test_model():
    """Smoke-test the models."""
    # Build and test the CIFAR-10 model
    print("=== Testing the CIFAR-10 model ===")
    model = NetworkFactory.create_cifar10_model()
    print("Model config:", model.config)
    # Print a structural summary
    print("\nModel summary:")
    summary(model, input_size=model.config.input_shape, device='cpu')
    # Test the forward pass
    test_input = torch.randn(8, *model.config.input_shape)  # batch of 8 samples
    output = model(test_input)
    print("\nForward pass test:")
    print(f"Input shape: {test_input.shape}")
    print(f"Output shape: {output.shape}")  # expected: (8, 10)
    # Test intermediate layer outputs
    conv1_output = model.get_layer_output(test_input, 'conv_1')
    fc1_output = model.get_layer_output(test_input, 'fc_1')
    print(f"conv_1 output shape: {conv1_output.shape}")
    print(f"fc_1 output shape: {fc1_output.shape}")
    # Test the MNIST model
    print("\n=== Testing the MNIST model ===")
    mnist_model = NetworkFactory.create_mnist_model()
    summary(mnist_model, input_size=mnist_model.config.input_shape, device='cpu')


# ---------------------------
# Entry point
# ---------------------------
if __name__ == "__main__":
    # Run the smoke tests
    test_model()
    # Demonstrate building a custom model
    print("\n=== Custom model example ===")
    custom_params = {
        "input_shape": (3, 64, 64),
        "num_classes": 100,
        "conv_layers": [
            {'out_channels': 64, 'kernel_size': 3, 'stride': 1, 'padding': 1, 'pool_size': 2},
            {'out_channels': 128, 'kernel_size': 3, 'stride': 1, 'padding': 1, 'pool_size': 2},
            {'out_channels': 256, 'kernel_size': 3, 'stride': 1, 'padding': 1, 'pool_size': 2},
            {'out_channels': 512, 'kernel_size': 3, 'stride': 1, 'padding': 1, 'pool_size': 2}
        ],
        "fc_layers": [
            {'units': 1024, 'dropout': 0.5},
            {'units': 512, 'dropout': 0.3}
        ],
        "activation": "leaky_relu",
        "use_batchnorm": True
    }
    custom_model = NetworkFactory.create_custom_model(custom_params)
    summary(custom_model, input_size=custom_params["input_shape"], device='cpu')
Code structure overview
This implementation uses a clean modular design with four main parts:

1. Configuration management (NetworkConfig)
- Stores all network parameters in one place instead of hard-coding them
- Validates parameters to catch configuration mistakes early
- Lets you define the conv and fc layer stacks flexibly

2. Basic building blocks
- ConvBlock: wraps a convolution, batch normalization, activation, and pooling layer
- FcBlock: wraps a linear layer, activation, and dropout
- Each component does one thing, which makes it easy to test and reuse

3. Main network (ImageClassifier)
- Builds the complete network by composing the basic blocks
- Computes every layer's input/output sizes automatically, with no manual bookkeeping
- Initializes weights following common best practices
- Exposes intermediate layer outputs for debugging and visualization

4. Network factory (NetworkFactory)
- Offers shortcuts for predefined models (CIFAR-10, MNIST)
- Supports creating new models from custom parameters
- Simplifies creating and managing families of similar networks
Usage
1. Create a predefined model
# Model for CIFAR-10
cifar_model = NetworkFactory.create_cifar10_model()
# Model for MNIST
mnist_model = NetworkFactory.create_mnist_model()

2. Create a custom model
custom_params = {
    "input_shape": (3, 64, 64),  # 3-channel 64x64 images
    "num_classes": 100,          # 100 classes
    "activation": "leaky_relu",  # use LeakyReLU activations
    # other parameters...
}
custom_model = NetworkFactory.create_custom_model(custom_params)

3. Get intermediate layer outputs (for debugging / visualization; a plotting sketch follows below)
# Output of the first conv block
conv1_output = model.get_layer_output(input_tensor, 'conv_1')
# Output of the first fc block
fc1_output = model.get_layer_output(input_tensor, 'fc_1')
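For a concrete use of this hook, here is a small plotting sketch. It is an illustrative addition rather than part of the original code: it assumes matplotlib is installed, and the random input stands in for a real image.

import torch
import matplotlib.pyplot as plt

model = NetworkFactory.create_cifar10_model()
model.eval()
input_tensor = torch.randn(1, 3, 32, 32)  # stand-in for a real preprocessed image

with torch.no_grad():
    fmap = model.get_layer_output(input_tensor, 'conv_1')  # shape (1, 32, 16, 16)

# Show the first 8 of the 32 feature maps produced by the first conv block
fig, axes = plt.subplots(1, 8, figsize=(16, 2))
for i, ax in enumerate(axes):
    ax.imshow(fmap[0, i].numpy(), cmap='viridis')  # one channel per panel
    ax.axis('off')
plt.show()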
The strength of this design: when the network structure needs to change, you only adjust configuration parameters rather than touching the core code; new functionality can be added by extending the basic components without disturbing existing code; and the clear structure makes team collaboration and long-term maintenance much easier. A minimal training-loop sketch follows below.
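To go from model construction to actual training, here is a minimal training-loop sketch. It is an illustrative addition rather than part of the original code: it assumes torchvision is available for CIFAR-10, uses cross-entropy loss with Adam, and treats the normalization statistics, batch size, learning rate, and epoch count as placeholder choices to adapt to your task.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = NetworkFactory.create_cifar10_model().to(device)

# Commonly used CIFAR-10 channel statistics (placeholder values)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
])
train_set = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_set, batch_size=128, shuffle=True)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

model.train()
for epoch in range(10):  # placeholder epoch count
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)  # raw logits go straight into CrossEntropyLoss
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"epoch {epoch + 1}: mean loss {running_loss / len(train_loader):.4f}")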
The modular PyTorch neural network code, explained line by line
Below, the modular network code above is explained piece by piece in plain language, to clarify what each part does.
1. Import the required libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
from typing import List, Tuple, Optional, Dict, Any

- import torch: the core PyTorch library, the foundation for deep learning
- import torch.nn as nn: PyTorch's neural network module, containing the various layers and activation functions
- import torch.nn.functional as F: PyTorch's functional interface with commonly used operations
- from torchsummary import summary: the summary utility, used to print a structural overview of the network
- from typing ...: type-hinting helpers that make parameter and return types explicit and the code easier to read
2. Configuration management (the NetworkConfig class)
class NetworkConfig:
    """Central store for all network hyperparameters, with validation."""

This defines a configuration class that manages all network parameters in one place, instead of scattering them throughout the code.
def __init__(self,
             input_shape: Tuple[int, int, int] = (3, 32, 32),
             num_classes: int = 10,
             conv_layers: List[Dict[str, Any]] = None,
             fc_layers: List[Dict[str, Any]] = None,
             activation: str = 'relu',
             use_batchnorm: bool = True,
             initializer: str = 'kaiming'):

The constructor initializes the configuration parameters:
- input_shape: shape of the input images; the default (3, 32, 32) means 3-channel 32x32 images
- num_classes: number of classes to predict, 10 by default
- conv_layers: list of conv block configurations, one dict per block
- fc_layers: list of fully connected block configurations
- activation: activation function type, ReLU by default
- use_batchnorm: whether to use batch normalization, enabled by default
- initializer: weight initialization scheme, Kaiming by default
# Default convolutional layer configuration
if conv_layers is None:
    conv_layers = [
        {'out_channels': 32, 'kernel_size': 3, 'stride': 1, 'padding': 1, 'pool_size': 2},
        {'out_channels': 64, 'kernel_size': 3, 'stride': 1, 'padding': 1, 'pool_size': 2},
        {'out_channels': 128, 'kernel_size': 3, 'stride': 1, 'padding': 1, 'pool_size': 2}
    ]

If the user supplies no conv configuration, a default 3-block stack is used:
- block 1: 32 output channels, 3x3 kernel, stride 1, padding 1, pool size 2
- block 2: 64 output channels, otherwise the same
- block 3: 128 output channels, otherwise the same
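To make the geometry concrete, here is the shape arithmetic for this default stack on a (3, 32, 32) input (a worked example added here, not part of the original code). Each 3x3 convolution with stride 1 and padding 1 preserves height and width, and each 2x2 max pool halves them:

import torch

# Shape walk for the default stack on a (3, 32, 32) input:
#   (3, 32, 32)  -> block 1 -> (32, 16, 16)   # conv keeps 32x32, pool halves it
#   (32, 16, 16) -> block 2 -> (64, 8, 8)
#   (64, 8, 8)   -> block 3 -> (128, 4, 4)
# Flattened size fed to the first fc block: 128 * 4 * 4 = 2048

model = NetworkFactory.create_cifar10_model()
model.eval()
features = model.features(torch.zeros(1, 3, 32, 32))
print(features.shape)                # torch.Size([1, 128, 4, 4])
print(features.view(1, -1).size(1))  # 2048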
# Default fully connected layer configuration
if fc_layers is None:
    fc_layers = [
        {'units': 512, 'dropout': 0.5},
    ]

If the user supplies no fc configuration, a single default block is used:
- 512 units with a 50% dropout rate
self.input_shape = input_shape
self.num_classes = num_classes
self.conv_layers = conv_layers
self.fc_layers = fc_layers
self.activation = activation
self.use_batchnorm = use_batchnorm
self.initializer = initializer

These lines save the parameters as instance attributes for later use.
# Validate the configuration
self._validate_config()

Calls the internal validation method to make sure the configuration is sensible and to catch bad parameter values early.
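As a quick illustration of this fail-fast behavior (using the class exactly as defined above), an even kernel size is rejected immediately:

# An invalid configuration fails fast with a descriptive AssertionError
try:
    NetworkConfig(conv_layers=[
        {'out_channels': 32, 'kernel_size': 4,  # even kernel: rejected
         'stride': 1, 'padding': 1, 'pool_size': 2}
    ])
except AssertionError as e:
    print(e)  # "conv layer 0: kernel_size must be odd"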
def _validate_config(self):
    """Check that all configuration parameters are valid."""
    # Validate input shape
    assert len(self.input_shape) == 3, "input_shape must be (channels, height, width)"
    assert all(isinstance(x, int) and x > 0 for x in self.input_shape), "input_shape entries must be positive integers"

_validate_config is an internal method that checks the parameters:
- the input shape must have exactly 3 elements (channels, height, width)
- each element of the input shape must be a positive integer
# Validate conv layer configs
for i, layer in enumerate(self.conv_layers):
    required_keys = ['out_channels', 'kernel_size', 'stride', 'padding', 'pool_size']
    assert all(key in layer for key in required_keys), f"conv layer {i} is missing required keys"
    assert layer['out_channels'] > 0, f"conv layer {i}: out_channels must be positive"
    assert layer['kernel_size'] % 2 == 1, f"conv layer {i}: kernel_size must be odd"

Each conv layer configuration is checked:
- it must contain every required key (output channels, kernel size, and so on)
- the output channel count must be positive
- the kernel size must be odd (so the padding can be symmetric)
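For reference (a worked equation added here, not in the original), the standard convolution output-size formula shows why odd kernels pair naturally with symmetric padding: out = floor((in + 2*padding - kernel) / stride) + 1, so an odd kernel k with padding (k - 1) / 2 at stride 1 preserves the spatial size.

def conv_out(in_size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Standard conv output-size formula."""
    return (in_size + 2 * padding - kernel) // stride + 1

print(conv_out(32, kernel=3, padding=1))  # 32: size preserved at stride 1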
# Validate fc layer configs
for i, layer in enumerate(self.fc_layers):
    assert 'units' in layer and layer['units'] > 0, f"fc layer {i}: units must be positive"
    assert 'dropout' in layer and 0 <= layer['dropout'] < 1, f"fc layer {i}: dropout must be in [0, 1)"

Each fully connected layer configuration is checked:
- it must specify a positive number of units
- it must specify a dropout rate in the range [0, 1)
# Validate activation
valid_activations = ['relu', 'leaky_relu', 'elu', 'sigmoid']
assert self.activation in valid_activations, f"activation must be one of {valid_activations}"

Ensures the activation function is one of the supported types.
# Validate initializer
valid_initializers = ['kaiming', 'xavier', 'normal']
assert self.initializer in valid_initializers, f"initializer must be one of {valid_initializers}"

Ensures the weight initialization scheme is one of the supported types.
def __repr__(self):
    """Readable configuration summary."""
    return (f"NetworkConfig(input_shape={self.input_shape}, num_classes={self.num_classes}, "
            f"conv_layers={len(self.conv_layers)}, fc_layers={len(self.fc_layers)}, "
            f"activation={self.activation}, use_batchnorm={self.use_batchnorm})")

Defines what is shown when the object is printed, giving a quick summary of the configuration.
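A quick check of what this prints for the default configuration defined above:

config = NetworkConfig()
print(config)
# NetworkConfig(input_shape=(3, 32, 32), num_classes=10, conv_layers=3, fc_layers=1, activation=relu, use_batchnorm=True)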
3. Basic building blocks: the convolutional block (ConvBlock class)
class ConvBlock(nn.Module):
    """Convolutional block: conv -> batch norm -> activation -> pooling."""

Defines the convolutional block class. It inherits from PyTorch's Module class and bundles one complete convolution pipeline.
def __init__(self,
             in_channels: int,
             out_channels: int,
             kernel_size: int,
             stride: int = 1,
             padding: int = 1,
             pool_size: int = 2,
             activation: str = 'relu',
             use_batchnorm: bool = True):
    super().__init__()

The conv block's constructor takes:
- in_channels: number of input channels
- out_channels: number of output channels
- kernel_size: convolution kernel size
- stride: stride
- padding: padding
- pool_size: pooling window size
- activation: activation function type
- use_batchnorm: whether to use batch normalization

super().__init__() calls the parent class constructor, the standard first step for any PyTorch module.
# Convolutional layer
self.conv = nn.Conv2d(
    in_channels=in_channels,
    out_channels=out_channels,
    kernel_size=kernel_size,
    stride=stride,
    padding=padding
)

Creates the convolutional layer from the constructor arguments.
# Optional batch normalization
self.bn = nn.BatchNorm2d(out_channels) if use_batchnorm else None

Creates a batch normalization layer when requested, otherwise stores None.
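As a standalone illustration of what nn.BatchNorm2d does (an addition here, not from the original code): in training mode it normalizes each channel over the batch and spatial dimensions, then applies a learnable scale (initialized to 1) and shift (initialized to 0).

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(2)
x = torch.randn(8, 2, 4, 4) * 5 + 3  # batch whose channels have mean ~3, std ~5
y = bn(x)
print(y.mean(dim=(0, 2, 3)))  # per-channel means, close to 0
print(y.std(dim=(0, 2, 3)))   # per-channel stds, close to 1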
# Activation function
self.activation = self._get_activation(activation)

Looks up the activation function by name through the internal _get_activation method.
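A quick demo of the lookup, using the ConvBlock defined earlier (note the 0.1 negative slope it configures for leaky_relu):

block = ConvBlock(in_channels=3, out_channels=32, kernel_size=3, activation='leaky_relu')
print(block.activation)  # LeakyReLU(negative_slope=0.1, inplace=True)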