import torch import torchvision.datasets as dsets import torchvision.transforms as transforms import os import sys path = os.path.split(os.path.abspath(os.path.realpath(sys.argv[0])))[0] + os.path.sep path = path[:-10] + '/data/' #/********** Begin *********/ # ********** End ********** #
时间: 2025-05-20 11:38:27 浏览: 25
### 加载和处理 MNIST 数据集并创建 DataLoader 变量
以下是关于如何在 PyTorch 中加载和处理 MNIST 数据集的具体说明,包括创建 `DataLoader` 对象以及输出数据维度大小和类别的方法。
---
#### 1. 导入必要库
首先需要导入用于数据加载和处理的相关模块。
```python
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
```
---
#### 2. 定义数据转换操作
为了将图像数据转换为适合神经网络输入的形式,并对其进行标准化处理,可以定义一个组合的变换函数。
```python
transform = transforms.Compose([
transforms.ToTensor(), # 将图像转换为 Tensor,并缩放到 [0, 1] 范围
transforms.Normalize((0.5,), (0.5,)) # 进一步归一化到 [-1, 1] 范围
])
```
此处使用的 `transforms.Normalize` 是针对单通道灰度图像设计的,均值 `(0.5,)` 和标准差 `(0.5,)` 的设置是为了使像素值分布更加集中于零附近[^1]。
---
#### 3. 下载和加载 MNIST 数据集
利用 `torchvision.datasets.MNIST` 类来下载和加载 MNIST 数据集。如果本地不存在该数据集,则会自动从官方源下载;否则直接使用已有的缓存文件。
```python
# 训练集
train_dataset = datasets.MNIST(
root='./data', # 存储路径
train=True, # 是否为训练集
transform=transform, # 应用的数据转换
download=True # 自动下载开关
)
# 测试集
test_dataset = datasets.MNIST(
root='./data',
train=False,
transform=transform,
download=True
)
```
---
#### 4. 创建 DataLoader 对象
通过 `DataLoader` 封装数据集实例,从而实现批量加载、随机打乱等功能。
```python
train_loader = DataLoader(
dataset=train_dataset, # 输入数据集
batch_size=64, # 每批次样本数量
shuffle=True, # 是否在每轮 epoch 开始前重新排列数据顺序
num_workers=2 # 后台线程数(加速数据读取)
)
test_loader = DataLoader(
dataset=test_dataset,
batch_size=1000,
shuffle=False
)
```
在此处设置了两个主要参数:`batch_size` 控制每次传递给模型的样本数目,而 `shuffle` 则决定是否对数据进行随机重排以打破可能存在的序列依赖关系[^2]。
---
#### 5. 输出特定数据的维度大小和类别
遍历 `DataLoader` 获取第一批数据,并打印其形状及对应的标签信息。
```python
for images, labels in train_loader:
print(f"Images shape: {images.shape}") # 图像张量尺寸
print(f"Labels: {labels[:10]}") # 前十个标签作为示例显示
break # 仅查看首个批次的内容
```
假设当前批次大小为 64,则预期输出如下所示:
```
Images shape: torch.Size([64, 1, 28, 28]) # 表明有 64 张图片,每张大小为 28x28 单通道灰度图
Labels: tensor([7, 9, 3, ..., 2, 5, 1]) # 展现部分真实数字标签
```
以上即完成了整个流程——从数据准备到最后验证阶段的关键步骤概述[^1]。
---
###
阅读全文
相关推荐

















