对mnist数据集整理成training.pt和test.pt
时间: 2025-03-07 08:13:03 浏览: 35
### 将MNIST数据集处理并保存为PyTorch的`training.pt`和`test.pt`文件
为了将MNIST数据集转换成适合长期存储以及快速读取的形式,可以利用Python中的pickle序列化机制或是更高效的torch.save方法来创建`.pt`格式的数据文件。具体实现如下:
#### 导入必要的库
```python
import torch
from torchvision import datasets, transforms
```
#### 定义数据预处理变换
定义标准化操作以及其他可能需要应用于图像的操作链。
```python
transform = transforms.Compose([
transforms.ToTensor(),
])
```
#### 加载原始MNIST数据集
使用内置函数加载训练集与测试集,并应用上述定义的转换规则[^2]。
```python
train_dataset = datasets.MNIST(root='./data',
train=True,
transform=transform,
download=True)
test_dataset = datasets.MNIST(root='./data',
train=False,
transform=transform)
```
#### 提取消耗时间较长的部分——即实际样本及其标签
提取出张量形式的数据以便后续保存。
```python
train_data = train_dataset.data.float() / 255.
train_labels = train_dataset.targets.long()
test_data = test_dataset.data.float() / 255.
test_labels = test_dataset.targets.long()
```
#### 构建字典结构用于保存
将图片数据与对应的类别标签组合起来形成易于管理的整体对象。
```python
train_dict = {
'images': train_data.unsqueeze(1), # 增加通道维度 (N,H,W)->(N,C,H,W),C=1对于灰度图像是固定的
'labels': train_labels
}
test_dict = {
'images': test_data.unsqueeze(1),
'labels': test_labels
}
```
#### 使用`torch.save()`持久化至磁盘
最终一步就是调用此API完成序列化过程并将结果存放在指定路径下的目标文件中。
```python
torch.save(train_dict, './data/training.pt')
torch.save(test_dict, './data/test.pt')
```
这样就完成了从下载到本地缓存再到自定义二进制格式(`.pt`)的一系列流程[^3]。
阅读全文
相关推荐











