import torch import torchvision.datasets as dsets import torchvision.transforms as transforms #/********** Begin *********/ # 下载MNIST数据集 Mnist_dataset = dsets.MNIST(root='./data',train=True,transform=transforms.ToTensor(),download=True) # 创建batch_size=100, shuffle=True的DataLoader类型的变量data_loader data_loader = torch.utils.data.DataLoader(dataset=Mnist_dataset,batch_size=100,shuffle=True) # 输出 data_loader中数据类型 print(type(data_loader.dataset)) #/********** End *********/ 输出结果<class 'torchvision.datasets.mnist.MNIST'> Congratulation!
时间: 2025-05-21 15:36:58 浏览: 25
### PyTorch 使用 `torchvision.datasets` 下载 MNIST 数据集并创建 DataLoader
以下是关于如何使用 PyTorch 中的 `torchvision.datasets` 来下载 MNIST 数据集,并将其封装到 `DataLoader` 的详细解释和示例代码。
#### 1. 导入库
为了实现数据集的下载与加载,需要引入必要的库:
```python
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
```
#### 2. 定义数据预处理操作
在实际应用中,通常会对图像进行一些标准化或增强的操作。对于 MNIST 数据集,常见的做法是将 PIL 图像转换为张量 (Tensor),并对像素值进行归一化。
```python
transform = transforms.Compose([
transforms.ToTensor(), # 将 PIL 图像转为 Tensor 并缩放到 [0, 1] 范围
transforms.Normalize((0.5,), (0.5,)) # 对单通道灰度图进行均值和标准差归一化
])
```
此处使用的 `transforms.Normalize` 是针对单通道灰度图像设计的,其参数 `(0.5,)` 表示均值,`(0.5,)` 表示标准差[^2]。
#### 3. 加载 MNIST 数据集
通过 `torchvision.datasets.MNIST` 类可以轻松下载和加载 MNIST 数据集。可以通过设置 `train=True` 或 `train=False` 来分别获取训练集和测试集。
```python
train_dataset = datasets.MNIST(
root='./data', # 存储路径
train=True, # 是否为训练集
transform=transform, # 预处理方法
download=True # 自动下载数据集
)
test_dataset = datasets.MNIST(
root='./data',
train=False,
transform=transform,
download=True
)
```
以上代码片段中的 `root` 参数指定了存储数据集的本地文件夹位置,而 `download=True` 则会在未找到数据时自动从网络下载[^3]。
#### 4. 创建 DataLoader
`DataLoader` 提供了一种高效的方式来批量加载数据,并支持多线程加速。
```python
batch_size = 64
train_loader = DataLoader(
dataset=train_dataset,
batch_size=batch_size,
shuffle=True # 训练过程中打乱顺序
)
test_loader = DataLoader(
dataset=test_dataset,
batch_size=batch_size,
shuffle=False # 测试时不需打乱顺序
)
```
在此处,`shuffle=True` 确保每次迭代都会随机排列训练数据,从而提高模型泛化能力[^1]。
#### 5. 确认数据类型
MNIST 数据集中每幅图片大小为 \(28 \times 28\) 像素,经过 `ToTensor()` 处理后会被转化为形状为 `[batch_size, 1, 28, 28]` 的浮点型张量(其中第一个维度表示批次大小)。标签则是长度等于批次大小的一维整数向量。
```python
for images, labels in train_loader:
print(f'Images shape: {images.shape}') # 输出形如:torch.Size([64, 1, 28, 28])
print(f'Labels shape: {labels.shape}') # 输出形如:torch.Size([64])
break # 只查看第一批数据即可
```
---
### 总结
上述过程展示了完整的流程,包括定义数据预处理方式、下载 MNIST 数据集、构建 `DataLoader` 实现批量化加载以及验证最终的数据形式。这是一套典型的机器学习项目前期准备工作的实践方案。
---
阅读全文
相关推荐
















