用一个实际例子和简单代码来清晰解释 batch、epoch 和 iteration 的关系:
------------------------------------------------------------------------------------
假设场景
-
你有一个数据集:1000 张猫狗图片
-
你设置 batch_size = 100(每次处理 100 张图片)
-
你计划训练 5 个 epoch
概念关系图解
1个Epoch = 完整遍历整个数据集1次 │ ├── Iteration 1: 处理第1-100张图片 (batch 1) ├── Iteration 2: 处理第101-200张图片 (batch 2) ├── ... └── Iteration 10: 处理第901-1000张图片 (batch 10)
具体计算
-
总样本数:1000 张图片
-
Batch size:100(每次处理的图片数)
-
Iterations per epoch = 总样本数 / batch_size = 1000 / 100 = 10 次
-
Total iterations = Iterations per epoch × Epochs = 10 × 5 = 50 次
代码示例说明
import torch
from torch.utils.data import DataLoader, TensorDataset# 创建模拟数据集:1000张图片(用1000个数字代替)
data = torch.arange(0, 1000) # [0, 1, 2, ..., 999]
dataset = TensorDataset(data)# 创建数据加载器:设置batch_size=100
dataloader = DataLoader(dataset, batch_size=100, shuffle=True)# 训练5个epoch
for epoch in range(5):
print(f"\n=== 开始第 {epoch+1} 个epoch ===")
# 每个epoch内遍历所有batch
for batch_idx, batch_data in enumerate(dataloader):
# 获取当前batch的数据
images = batch_data[0] # 当前batch的100张"图片"
# 这里应该是你的训练代码:
# 1. 正向传播
# 2. 计算损失
# 3. 反向传播
# 4. 更新权重
# 打印当前iteration信息
print(f"Epoch {epoch+1} | Iteration {batch_idx+1} | 处理图片: {images[0].item()}到{images[-1].item()}")print("\n训练完成!")
print(f"总Iteration次数: {5 * len(dataloader)} 次")
import torch
from torch.utils.data import DataLoader, TensorDataset
# 创建模拟数据集:1000张图片(用1000个数字代替)
data = torch.arange(0, 1000) # [0, 1, 2, ..., 999]
dataset = TensorDataset(data)
# 创建数据加载器:设置batch_size=100
dataloader = DataLoader(dataset, batch_size=100, shuffle=True)
# 训练5个epoch
for epoch in range(5):
print(f"\n=== 开始第 {epoch+1} 个epoch ===")
# 每个epoch内遍历所有batch
for batch_idx, batch_data in enumerate(dataloader):
# 获取当前batch的数据
images = batch_data[0] # 当前batch的100张"图片"
# 这里应该是你的训练代码:
# 1. 正向传播
# 2. 计算损失
# 3. 反向传播
# 4. 更新权重
# 打印当前iteration信息
print(f"Epoch {epoch+1} | Iteration {batch_idx+1} | 处理图片: {images[0].item()}到{images[-1].item()}")
print("\n训练完成!")
print(f"总Iteration次数: {5 * len(dataloader)} 次")
输出示例
=== 开始第 1 个epoch ===
Epoch 1 | Iteration 1 | 处理图片: 12到980
Epoch 1 | Iteration 2 | 处理图片: 88到799
...
Epoch 1 | Iteration 10 | 处理图片: 36到995
=== 开始第 2 个epoch ===
Epoch 2 | Iteration 1 | 处理图片: 44到932
...
(共5个epoch,每个epoch 10个iteration)
训练完成!
总Iteration次数: 50 次
关键概念解析
-
Batch(批):
-
每次实际输入模型的数据子集
-
代码中的
images = batch_data[0]
获取的就是一个batch -
大小由
batch_size=100
决定
-
-
Iteration(迭代):
-
完成一个batch训练所需的步骤
-
每个iteration包含:
-
从dataloader取出一个batch数据
-
执行正向传播 → 计算损失 → 反向传播 → 更新权重
-
-
代码中
for batch_idx, ...
循环的每次执行就是一个iteration
-
-
Epoch(轮次):
-
完整遍历整个数据集一次
-
每个epoch包含多个iteration:
-
1000样本 ÷ 100 batch_size = 10 iterations/epoch
-
-
外层循环
for epoch in range(5)
控制epoch数量
-
为什么需要batch?
-
内存限制:无法一次性加载所有数据(如100万张图片)
-
训练效率:小批量数据更适合GPU并行计算
-
梯度稳定性:批量梯度下降比单个样本更稳定
-
正则化效果:小批量带来轻微噪声,有助于防止过拟合
实际训练中的选择
Batch Size | 迭代次数 | 内存占用 | 训练稳定性 |
---|---|---|---|
小 (8-32) | 多 | 低 | 较低 |
中 (64-256) | 中等 | 中等 | 较好 |
大 (512+) | 少 | 高 | 高 |
初学者建议从 batch_size=64 开始尝试,这是常用基准值。你之前提到的8也是合理的,尤其当使用大模型或显存有限时。