class MyDataSet3D(Dataset): """自定义用于3DCT数据集"""
时间: 2025-02-17 19:09:39 浏览: 52
### 创建继承自 PyTorch `Dataset` 类的自定义 3D CT 数据集
为了创建一个继承自 PyTorch 的 `torch.utils.data.Dataset` 类的自定义 3D CT 数据集,需要重写两个主要方法:`__init__`, `__len__` 和 `__getitem__` 方法。
#### 初始化 (`__init__`) 方法
初始化函数负责加载数据并执行任何必要的预处理操作。对于 3D CT 数据集来说,这通常意味着读取文件路径列表、标签和其他元数据。
```python
import os
from torch.utils.data import Dataset
import nibabel as nib
import numpy as np
class Custom3DCTDataset(Dataset):
def __init__(self, data_dir, transform=None):
self.transform = transform
self.image_paths = []
self.label_paths = []
# 假设图像和标签在同一目录下,并且有相同的名称前缀
for file_name in os.listdir(data_dir):
if 'image' in file_name and file_name.endswith('.nii.gz'):
image_path = os.path.join(data_dir, file_name)
label_file_name = file_name.replace('image', 'label')
label_path = os.path.join(data_dir, label_file_name)
if os.path.exists(label_path): # 只添加存在对应标签的数据项
self.image_paths.append(image_path)
self.label_paths.append(label_path)
```
此部分代码构建了一个包含所有可用样本及其相应标注路径的列表[^1]。
#### 获取长度 (`__len__`) 方法
该方法返回数据集中项目的总数目:
```python
def __len__(self):
return len(self.image_paths)
```
通过这种方式可以知道有多少个独立的训练/测试实例被加入到了这个定制化的数据集中。
#### 获取项目 (`__getitem__`) 方法
这是最重要的部分之一,在这里实现了如何访问单个项目以及对其进行转换(如果有的话)。下面是一个简单的例子来展示这一点:
```python
def __getitem__(self, idx):
img_path = self.image_paths[idx]
lbl_path = self.label_paths[idx]
# 加载NIfTI格式的三维医学影像数据
img_nifti = nib.load(img_path).get_fdata()
lbl_nifti = nib.load(lbl_path).get_fdata()
sample = {'image': img_nifti.astype(np.float32), 'label': lbl_nifti}
if self.transform:
sample = self.transform(sample)
return sample['image'], sample['label']
```
这段代码展示了如何从指定位置获取一对输入特征 (即CT扫描) 和对应的输出目标 (如病变区域分割)。
阅读全文
相关推荐


















