PyTorch Lightning 生命周期流程
时间: 2025-04-30 20:27:41 浏览: 27
### PyTorch Lightning 生命周期阶段及其执行顺序
PyTorch Lightning 的设计旨在简化机器学习模型训练过程中的复杂度,通过定义清晰的生命周期来管理实验。以下是主要的生命期阶段以及它们的执行顺序:
#### 初始化 Trainer 和 Model
在创建 `Trainer` 实例时可以指定多种参数配置,这些设置会决定后续训练行为[^1]。
```python
trainer = pl.Trainer(
max_epochs=5,
gpus=1, # 使用 GPU 数量
logger=WandbLogger(), # 日志记录器
)
model = MyModel()
```
#### 配置优化器和调度器
当调用 `configure_optimizers()` 方法返回优化算法实例或者包含优化器与学习率调整策略在内的字典结构。
```python
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=0.02)
scheduler = {
'scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer),
'monitor': 'val_loss'
}
return [optimizer], [scheduler]
```
#### 训练循环前准备
- **on_fit_start()**: 整体 fit 过程开始之前被触发。
- **setup(stage)**: 可用于数据集划分或其他初始化工作,在每个阶段(train/validate/test/predict)启动前运行一次。
#### 单轮 Epoch 开始
对于每一个 epoch 来说有如下钩子函数:
- **on_train_epoch_start()**
接着进入 batch-level 循环处理各个 mini-batch 数据直到完成整个 dataset 的遍历。
#### Batch 处理逻辑
针对每一批次的数据存在以下方法供开发者重写实现自定义操作:
- **training_step(batch, batch_idx)**
- **validation_step(batch, batch_idx)**
- **test_step(batch, batch_idx)**
- **predict_step(batch, batch_idx)**
上述四个步骤分别对应于不同模式下的单步计算逻辑,其中 training_step 是最核心的部分因为它涉及到反向传播更新权重等重要环节。
#### 批次结束后的汇总统计
每当一个完整的 epoch 结束之后都会依次调用下面的方法来进行性能评估并保存checkpoint等工作:
- **on_validation_end()**
- **on_test_end()**
最后在整个训练周期结束后还会有一个最终收尾动作即 on_fit_end()。
#### 测试预测阶段
除了常规训练外还可以单独开启测试或推理流程,此时仅需提供相应 dataloader 并调用 trainer.test 或者 predict 接口即可。
```python
results = trainer.test(model=model, dataloaders=test_dataloader)
predictions = trainer.predict(model=model, dataloaders=predict_dataloader)
```
阅读全文
相关推荐


















