1. 前言
在 PyTorch 中,IterableDataset
是一种用于处理大量数据或无法一次性加载到内存中的数据集类。与 Dataset
不同的是,IterableDataset
是基于迭代器的,适用于顺序数据读取。
2. 基本使用方法
需要创建一个继承自 torch.utils.data.IterableDataset
的自定义数据集类,并实现 __iter__
方法。
import torch
from torch.utils.data import IterableDataset, DataLoader
class MyIterableDataset(IterableDataset):
def __init__(self, start, end):
super(MyIterableDataset).__init__()
self.start = start
self.end = end
def __iter__(self):
# 返回一个生成器
return iter(range(self.start, self.end))
# 创建数据集实例
dataset = MyIterableDataset(start=0, end=<