本专题主要是解决Pytorch框架下项目的数据预处理工作
Table of Contents:
1. HDF5文件简介
2. Python中的_, __, __xx__区别
3. Dataset类
4. DataLoader类
有了前面两节内容的铺垫,这一节内容应该就比较好理解了。
本文主要将如何在Pytorch中自定义一个Dataset类来封装训练数据与测试数据,即使得数据保存在一个自定义数据类型对象(类似列表、字典等类型对象)里面。
先以一个简单的不可变容器对象为例,在Pytorch数据预处理中通常都是定义一个不可变容器对象。
# h5dataset.py
import h5py
import torch
import torch.utils.data as data
import glob
import os
from common import *
import numpy as np
# import nibabel as nib
# BCHW order
class H5Dataset(data.Dataset):
def __init__(self, root_path, crop_size=crop_size, mode='train'):
self.hdf5_list = [x for x in glob.glob(os.path.join(root_path, '*.h5'))]
self.crop_size = crop_size
self.mode = mode
if self.mode == 'train':
self.hdf5_list = self.hdf5_list + self.hdf5_list + self.hdf5_list + self.hdf5_list
def __getitem__(self, index):
h5_file = h5py.File(self.hdf5_list[index])
self.data = h5_file.get('data')
self.label = h5_file.get('label')
self.label = self.label[:, 0, ...]
_, _, C, H, W = self.data.shape
if self.mode == 'train':
cx = random.randint(0, C - self.crop_size[0])
cy = random.randint(0, H - self.crop_size[1])
cz = random.randint(0, W - self.crop_size[2])
elif self.mode == 'val':
# -------Center crop----------
cx = (C - self.crop_size[0]) // 2
cy = (H - self.crop_size[1]) // 2
cz = (W - self.crop_size[2]) // 2
self.data_crop = self.data[:, :, cx: cx + self.crop_size[0], cy: cy + self.crop_size[1],
cz: cz + self.crop_size[2]]
self.label_crop = self.label[:, cx: cx + self.crop_size[0], cy: cy + self.crop_size[1],
cz: cz + self.crop_size[2]]
# print(self.data_crop.shape, self.label_crop.shape)
# ------End random crop-------------
# h5_file.close()
return (torch.from_numpy(self.data_crop[0, :, :, :, :]).float(),
torch.from_numpy(self.label_crop[0, :, :, :]).long())
def __len__(self):
return len(self.hdf5_list)
对于这个自定义的 H5Dataset 类有以下几点说明:
- def __getitem__(self, index): 方法返回一个训练样本,即在自定义容器 dataset 执行 dataset[i] 时自动调用该方法,然后返回训练样本。(注意:这里一个 *.h5 文件存储一个样本。如何一个文件存储多个样本怎么办?如何保证遍历完这个文件里所有样本都被遍历?)
- 参数 mode 决定返回训练样本还是验证样本。
- 对于不可变容器来说,这三个函数是必须要的,其他功能可以增添函数来完善。