这行代码 self.log('lr_abs', lr, ...) 是 PyTorch Lightning 中用于记录学习率的关键操作。
一、核心功能与参数
self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
1. 参数解释:
- 'lr_abs':日志指标的名称(在 TensorBoard 或进度条中显示为学习率)。
- lr:当前的学习率值,通常从优化器中获取(如 optimizer.param_groups[0]['lr'])。
- prog_bar=True:在终端的进度条上显示该指标。
- logger=True:将指标写入日志文件(如 TensorBoard/WandB)。
- on_step=True:每处理一个批次就记录一次。
- on_epoch=False:不在每个 epoch 结束时重复记录。
二、学习率的重要性
学习率(Learning Rate, LR)是深度学习训练中最关键的超参数之一,它控制着:
1.参数更新的步长:
- LR 过大:参数更新步长过大,可能导致损失震荡或无法收敛。
- LR 过小:参数更新步长过小,训练速度缓慢,可能陷入局部最优。
2.与优化器的协同作用:
- 不同优化器(如 Adam、SGD)对学习率的敏感性不同。
- 例如:Adam 通常使用 0.001,而 SGD 可能需要 0.01 或更高。
三、为什么需要记录学习率?
1. 监控学习率调度器
许多训练策略会动态调整学习率,如:
- 余弦退火(Cosine Annealing):学习率随训练进程呈余弦曲线下降。
- 阶梯衰减(Step Decay):在特定步数或 epoch 时学习率减半。
- ReduceLROnPlateau:当验证损失停滞时自动降低学习率。
记录学习率可以验证调度器是否按预期工作。
2. 诊断训练问题
- 损失突然增大:可能是学习率过大导致参数更新不稳定。
- 训练停滞:可能是学习率过小,模型无法跳出局部最优。
通过观察学习率与损失的关系,可以快速定位问题。
3. 复现实验
学习率是实验配置的重要组成部分,记录它有助于:
- 准确复现训练过程。
- 比较不同学习率策略的效果。
四、实际运行效果
1. 终端进度条显示
Epoch 2: 30%|███ | 30/100 [00:05<00:12, 5.98it/s, lr_abs=0.0008]
实时显示当前学习率(如 lr_abs=0.0008)。
五、学习率监控的最佳实践
1. 多参数组学习率
如果模型不同部分使用不同学习率(如预训练骨干网络和新分类头):
# 记录所有参数组的学习率
for i, param_group in enumerate(self.optimizers().param_groups):
self.log(f'lr_group_{i}', param_group['lr'], on_step=True, on_epoch=False)
2. 结合损失曲线分析
在 TensorBoard 中同时查看学习率和损失曲线:
- 当学习率下降时,损失应该继续减小。
- 如果学习率下降后损失仍停滞,可能需要更激进的学习率调整。
3. 不同阶段的学习率策略
- 预热阶段(Warmup):训练初期使用较小的学习率逐步增加。
- 微调阶段(Fine-tuning):使用比预训练更小的学习率。
七、总结
记录学习率是深度学习训练中的关键实践,它提供了:
- 训练过程的透明度:了解学习率如何随时间变化。
- 调试工具:通过学习率与损失的关系诊断训练问题。
- 实验可复现性:准确记录超参数配置。
在 PyTorch Lightning 中,通过 self.log('lr_abs', lr) 简单一行代码,即可实现这一重要功能。