一、权值初始化
在深度学习中,权值初始化是神经网络训练中关键的一步。正确的权值初始化可以帮助加速收敛速度,提高模型的稳定性和泛化能力。在PyTorch中,我们可以通过
torch.nn.init
模块中的函数来实现不同的权值初始化方法。1.常规初始化
PyTorch通常会自动处理权值的初始化,但我们也可以手动选择不同的初始化方法。常见的初始化方法包括正态分布和均匀分布。例如,我们可以使用
torch.nn.init.normal_()
来以正态分布初始化权值,或者使用torch.nn.init.uniform_()
以均匀分布初始化权值。import numpy as np def normal_init(shape): return np.random.normal(0, 1, size=shape) # 使用正态分布初始化示例 weight_shape = (256, 784) # 举例隐藏层到输入层权重的形状 normal_weights = normal_init(weight_shape) print("正态分布初始化权重:", normal_weights)
2.Xavier初始化
Xavier初始化是一种常用的权值初始化方法,旨在保持每一层输入输出的方差相等。这有助于避免梯度消失或梯度爆炸问题。在PyTorch中,我们可以使用
torch.nn.init.xavier_normal_()
或torch.nn.init.xavier_uniform_()
函数来实现Xavier初始化。import numpy as np def xavier_init(shape): fan_in = shape[0] fan_out = shape[1] limit = np.sqrt(6.0 / (fan_in + fan_out)) return np.random.uniform(-limit, limit, size=shape) # 使用Xavier初始化示例 weight_shape = (256, 784) # 举例隐藏层到输入层权重的形状 xavier_weights = xavier_init(weight_shape) print("Xavier初始化权重:", xavier_weights)
3.He初始化
He初始化是专门针对ReLU激活函数设计的初始化方法,旨在更好地适应ReLU函数的性质。它可以帮助网络学习更快并提高性能。在PyTorch中,我们可以使用
torch.nn.init.kaiming_normal_()
或torch.nn.init.kaiming_uniform_()
函数来实现He初始化。import numpy as np def he_init(shape): fan_in = shape[0] limit = np.sqrt(2.0 / fan_in) return np.random.normal(0, limit, size=shape) # 使用He初始化示例 weight_shape = (256, 784) # 举例隐藏层到输入层权重的形状 he_weights = he_init(weight_shape) print("He初始化权重:", he_weights)
实例化代码
import torch import torch.nn as nn # 定义一个简单的神经网络模型 class SimpleNN(nn.Module): def __init__(self): super(SimpleNN, self).__init__() self.fc1 = nn.Linear(784, 256) # 输入层到隐藏层 self.fc2 = nn.Linear(256, 10) # 隐藏层到输出层 def forward(self, x): x = torch.relu(self.fc1(x)) x = self.fc2(x) return x # 初始化模型并应用Xavier初始化 model = SimpleNN() for name, param in model.named_parameters(): if 'weight' in name: nn.init.xavier_normal_(param) # 使用Xavier正态分布初始化权重 # 打印模型的权重 for name, param in model.named_parameters(): if 'weight' in name: print(name, param) # 模型的使用示例 input_data = torch.randn(1, 784) # 生成随机输入数据 output = model(input_data) # 进行模型推断 print("模型输出:", output)
运行结果
fc1.we