pytorch的数据集
时间: 2025-04-28 11:20:39 浏览: 12
### 创建和使用PyTorch数据集
#### 数据准备
为了有效地利用PyTorch进行机器学习项目的数据准备工作,可以采用简单的方法来创建自定义数据集。具体来说,可以通过编写Python脚本来加载图像文件并应用必要的预处理操作[^1]。
```python
import glob
import torch
from torch.utils import data
from PIL import Image
import numpy as np
from torchvision import transforms
import matplotlib.pyplot as plt
```
这段代码展示了导入所需库的过程,这些库对于后续的操作至关重要。通过`glob`模块获取文件列表;借助PIL库读取图片;运用Numpy执行数组运算;以及利用Matplotlib展示图像效果。同时,还引入了来自`torchvision.transforms`的各种变换工具用于增强数据多样性[^3]。
#### 构建自定义数据集类
当标准的内置数据集无法满足特定需求时,则需设计一个继承于`data.Dataset`的新类——这里命名为`MyDataset`。此类应重写两个核心方法:`__len__()`返回整个集合大小,而`__getitem__(self, idx)`负责按索引访问单条记录[^2]。
```python
class MyDataset(data.Dataset):
def __init__(self, image_paths, transform=None):
self.image_paths = image_paths
self.transform = transform
def __len__(self):
return len(self.image_paths)
def __getitem__(self, index):
path = self.image_paths[index]
img = Image.open(path).convert('RGB')
if self.transform is not None:
img = self.transform(img)
return img
```
上述实现了最基础版本的自定义数据集构造器,在初始化阶段接收一系列图像路径作为输入参数,并允许传入额外的转换规则应用于每张载入后的照片上[^4]。
#### 整合与迭代数据
完成以上步骤之后,便能够轻松地把实例化的对象传递给DataLoader来进行批量采样及随机打乱顺序等高级功能[^5]:
```python
dataset = MyDataset(image_paths=glob.glob('./images/*.jpg'),
transform=transforms.ToTensor())
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
for images in dataloader:
# 进行训练或其他处理...
pass
```
此部分演示了如何将之前建立好的`MyDataset`实例化为实际可用的对象,并设置每次取出多少样本组成一批次(`batch_size`)还有是否开启混洗模式(shuffle)以提高泛化能力。最后进入循环结构逐批次取得待用资料供下游任务调用。
阅读全文
相关推荐















