pytorch数据预处理——3. Pytorch Dataset类

本专题主要是解决Pytorch框架下项目的数据预处理工作
Table of Contents:
     1. HDF5文件简介
     2. Python中的_, __, __xx__区别
     3. Dataset类
     4. DataLoader类

有了前面两节内容的铺垫,这一节内容应该就比较好理解了。
本文主要将如何在Pytorch中自定义一个Dataset类来封装训练数据与测试数据,即使得数据保存在一个自定义数据类型对象(类似列表、字典等类型对象)里面。

先以一个简单的不可变容器对象为例,在Pytorch数据预处理中通常都是定义一个不可变容器对象。

# h5dataset.py
import h5py
import torch
import torch.utils.data as data
import glob
import os
from common import *
import numpy as np


# import nibabel as nib
# BCHW order
class H5Dataset(data.Dataset):

    def __init__(self, root_path, crop_size=crop_size, mode='train'):
        self.hdf5_list = [x for x in glob.glob(os.path.join(root_path, '*.h5'))]
        self.crop_size = crop_size
        self.mode = mode
        if self.mode == 'train':
            self.hdf5_list = self.hdf5_list + self.hdf5_list + self.hdf5_list + self.hdf5_list

    def __getitem__(self, index):
        h5_file = h5py.File(self.hdf5_list[index])
        self.data = h5_file.get('data')
        self.label = h5_file.get('label')
        self.label = self.label[:, 0, ...]
        _, _, C, H, W = self.data.shape
        if self.mode == 'train':
            cx = random.randint(0, C - self.crop_size[0])
            cy = random.randint(0, H - self.crop_size[1])
            cz = random.randint(0, W - self.crop_size[2])

        elif self.mode == 'val':
            # -------Center crop----------
            cx = (C - self.crop_size[0]) // 2
            cy = (H - self.crop_size[1]) // 2
            cz = (W - self.crop_size[2]) // 2

        self.data_crop = self.data[:, :, cx: cx + self.crop_size[0], cy: cy + self.crop_size[1],
                         cz: cz + self.crop_size[2]]
        self.label_crop = self.label[:, cx: cx + self.crop_size[0], cy: cy + self.crop_size[1],
                          cz: cz + self.crop_size[2]]
        # print(self.data_crop.shape, self.label_crop.shape)
        # ------End random crop-------------
        # h5_file.close()
        return (torch.from_numpy(self.data_crop[0, :, :, :, :]).float(),
                torch.from_numpy(self.label_crop[0, :, :, :]).long())

    def __len__(self):
        return len(self.hdf5_list)

对于这个自定义的 H5Dataset 类有以下几点说明:

  1. def __getitem__(self, index): 方法返回一个训练样本,即在自定义容器 dataset 执行 dataset[i] 时自动调用该方法,然后返回训练样本。(注意:这里一个 *.h5 文件存储一个样本。如何一个文件存储多个样本怎么办?如何保证遍历完这个文件里所有样本都被遍历?)
  2. 参数 mode 决定返回训练样本还是验证样本。
  3. 对于不可变容器来说,这三个函数是必须要的,其他功能可以增添函数来完善。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值