
model_args.compute_dtype (比如设置为 torch.bfloat16 或 torch.float16) 控制了 模型的计算精度,包括前向传播、损失计算等,梯度等。对于 优化器 的精度,通常不直接受到影响。优化器(如 Adam、AdamW)的参数更新依然使用 float32 精度
| 项 | 是否受fp16影响 |
|---|---|
| 模型参数 | fp16 |
| 输出 | fp16 |
| 损失函数计算 | fp16 |
| 梯度 | fp16 |
| 优化器 | fp32,不受影响 |

model_args.compute_dtype (比如设置为 torch.bfloat16 或 torch.float16) 控制了 模型的计算精度,包括前向传播、损失计算等,梯度等。对于 优化器 的精度,通常不直接受到影响。优化器(如 Adam、AdamW)的参数更新依然使用 float32 精度
| 项 | 是否受fp16影响 |
|---|---|
| 模型参数 | fp16 |
| 输出 | fp16 |
| 损失函数计算 | fp16 |
| 梯度 | fp16 |
| 优化器 | fp32,不受影响 |
2370
3189

被折叠的 条评论
为什么被折叠?