pytorch set_epoch()方法

本文解释了在分布式学习环境下,如何在PyTorch中正确设置epoch并确保DistributedSampler的shuffle功能有效,避免数据加载顺序固定。务必在每个epoch开始时调用set_epoch(),以实现每轮数据的重新打乱。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

        在分布式模式下,需要在每个 epoch 开始时调用 set_epoch() 方法,然后再创建 DataLoader 迭代器,以使 shuffle 操作能够在多个 epoch 中正常工作。 否则,dataloader迭代器产生的数据将始终使用相同的顺序。

        

sampler = DistributedSampler(dataset) if is_distributed else None
loader = DataLoader(dataset, shuffle=(sampler is None),
                    sampler=sampler)
for epoch in range(start_epoch, n_epochs):
    if is_distributed:
        sampler.set_epoch(epoch)
    train(loader)

参考:

        torch.utils.data — PyTorch 1.10.0 documentation

Pytorch DistributedDataParallel 数据采样 shuffle - 知乎更新:这是19年12月的文章,至今pytorch已经更新了不少,所以还有没有用请读者自己实验。 ———————————————————————————————————————— 原文 一个比较蛋疼的问题,就是我发…https://zhuanlan.zhihu.com/p/97115875 

评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值