Pytorch IterableDataset的使用

Pytorch IterableDataset的使用

背景

当数据量特别大,无法一次性load进内存时,Pytorch里的Dataset就无法胜任了,此时需要使用IterableDataset.

基本用法

只需要实现__init__()__iter__()__len__(),模版如下:

from torch.utils.data import IterableDataset, DataLoader

class MyIterableDataset(IterableDataset):
	def __init__(self):
		# 实现初始化代码
		pass
	
	def __iter__(self):
		# 返回一个数据的迭代器
		pass
	
	def __len__(self):
		# 返回数据长度
		pass

mydataset = MyIterableDataset()  # 可迭代对象
mydataloader = DataLoader(mydataset, shuffle=False, batch_size=batch_size, num_workers=num_workers)  # shuffle必须要是False

一个例子

读取CNN_Dailymail摘要数据集

class SummaryDataset(IterableDataset):                                          
                                                                                
    def __init__(self,                                                          
                 file_path: str                                                 
    ):                                                                          
        super(SummaryDataset).__init__()                                        
        self.file_path = file_path                                              
        self.info  = self._get_file_info(file_path)                             
        self.start = self.info['start']                                         
        self.end   = self.info['end']                                           
                                                                                
    def __iter__(self):                                                         
        worker_info = torch.utils.data.get_worker_info()                        
        if worker_info is None:  # single worker                                
            iter_start = self.start                                             
            iter_end   = self.end                                               
        else
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值