在深度学习中,batch、epoch 和 iteration 的关系

用一个实际例子和简单代码来清晰解释 batch、epoch 和 iteration 的关系:

------------------------------------------------------------------------------------

假设场景

  • 你有一个数据集:1000 张猫狗图片

  • 你设置 batch_size = 100(每次处理 100 张图片)

  • 你计划训练 5 个 epoch

概念关系图解

1个Epoch = 完整遍历整个数据集1次
    │
    ├── Iteration 1: 处理第1-100张图片 (batch 1)
    ├── Iteration 2: 处理第101-200张图片 (batch 2)
    ├── ...
    └── Iteration 10: 处理第901-1000张图片 (batch 10)

具体计算

  • 总样本数:1000 张图片

  • Batch size:100(每次处理的图片数)

  • Iterations per epoch = 总样本数 / batch_size = 1000 / 100 = 10 次

  • Total iterations = Iterations per epoch × Epochs = 10 × 5 = 50 次

代码示例说明

import torch
from torch.utils.data import DataLoader, TensorDataset

# 创建模拟数据集:1000张图片(用1000个数字代替)
data = torch.arange(0, 1000)  # [0, 1, 2, ..., 999]
dataset = TensorDataset(data)

# 创建数据加载器:设置batch_size=100
dataloader = DataLoader(dataset, batch_size=100, shuffle=True)

# 训练5个epoch
for epoch in range(5):
    print(f"\n=== 开始第 {epoch+1} 个epoch ===")
    
    # 每个epoch内遍历所有batch
    for batch_idx, batch_data in enumerate(dataloader):
        # 获取当前batch的数据
        images = batch_data[0]  # 当前batch的100张"图片"
        
        # 这里应该是你的训练代码:
        # 1. 正向传播
        # 2. 计算损失
        # 3. 反向传播
        # 4. 更新权重

        
        # 打印当前iteration信息
        print(f"Epoch {epoch+1} | Iteration {batch_idx+1} | 处理图片: {images[0].item()}到{images[-1].item()}")

print("\n训练完成!")
print(f"总Iteration次数: {5 * len(dataloader)} 次")

import torch
from torch.utils.data import DataLoader, TensorDataset

# 创建模拟数据集:1000张图片(用1000个数字代替)
data = torch.arange(0, 1000)  # [0, 1, 2, ..., 999]
dataset = TensorDataset(data)

# 创建数据加载器:设置batch_size=100
dataloader = DataLoader(dataset, batch_size=100, shuffle=True)

# 训练5个epoch
for epoch in range(5):
    print(f"\n=== 开始第 {epoch+1} 个epoch ===")
    
    # 每个epoch内遍历所有batch
    for batch_idx, batch_data in enumerate(dataloader):
        # 获取当前batch的数据
        images = batch_data[0]  # 当前batch的100张"图片"
        
        # 这里应该是你的训练代码:
        # 1. 正向传播
        # 2. 计算损失
        # 3. 反向传播
        # 4. 更新权重
        
        # 打印当前iteration信息
        print(f"Epoch {epoch+1} | Iteration {batch_idx+1} | 处理图片: {images[0].item()}到{images[-1].item()}")

print("\n训练完成!")
print(f"总Iteration次数: {5 * len(dataloader)} 次")

输出示例

=== 开始第 1 个epoch ===
Epoch 1 | Iteration 1 | 处理图片: 12到980
Epoch 1 | Iteration 2 | 处理图片: 88到799
...
Epoch 1 | Iteration 10 | 处理图片: 36到995

=== 开始第 2 个epoch ===
Epoch 2 | Iteration 1 | 处理图片: 44到932
...
(共5个epoch,每个epoch 10个iteration)

训练完成!
总Iteration次数: 50 次

关键概念解析

  1. Batch(批)

    • 每次实际输入模型的数据子集

    • 代码中的 images = batch_data[0] 获取的就是一个batch

    • 大小由 batch_size=100 决定

  2. Iteration(迭代)

    • 完成一个batch训练所需的步骤

    • 每个iteration包含:

      • 从dataloader取出一个batch数据

      • 执行正向传播 → 计算损失 → 反向传播 → 更新权重

    • 代码中 for batch_idx, ... 循环的每次执行就是一个iteration

  3. Epoch(轮次)

    • 完整遍历整个数据集一次

    • 每个epoch包含多个iteration:

      • 1000样本 ÷ 100 batch_size = 10 iterations/epoch

    • 外层循环 for epoch in range(5) 控制epoch数量

为什么需要batch?

  1. 内存限制:无法一次性加载所有数据(如100万张图片)

  2. 训练效率:小批量数据更适合GPU并行计算

  3. 梯度稳定性批量梯度下降比单个样本更稳定

  4. 正则化效果:小批量带来轻微噪声,有助于防止过拟合

实际训练中的选择

Batch Size迭代次数内存占用训练稳定性
小 (8-32)较低
中 (64-256)中等中等较好
大 (512+)

初学者建议从 batch_size=64 开始尝试,这是常用基准值。你之前提到的8也是合理的,尤其当使用大模型或显存有限时

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值