MNIST 数据集下载
时间: 2025-05-21 22:34:18 浏览: 48
### 如何下载并使用MNIST数据集
MNIST 数据集是一个广泛使用的手写数字识别数据集,包含 60,000 张训练图像和 10,000 张测试图像,每张图片大小为 \(28 \times 28\) 像素。以下是关于如何下载和加载 MNIST 数据集的具体说明。
#### 使用 Python 和 `tensorflow` 或 `torchvision` 下载 MNIST 数据集
如果正在使用 TensorFlow 或 PyTorch 这样的框架,则可以通过内置工具轻松获取 MNIST 数据集:
##### 方法一:通过 TensorFlow/Keras 加载 MNIST 数据集
TensorFlow 提供了一个简单的方法来加载 MNIST 数据集:
```python
from tensorflow.keras.datasets import mnist
# 加载数据集
(x_train, y_train), (x_test, y_test) = mnist.load_data()
print(f"Training data shape: {x_train.shape}") # 输出应为 (60000, 28, 28)[^2]
print(f"Test data shape: {x_test.shape}") # 输出应为 (10000, 28, 28)
```
##### 方法二:通过 TorchVision 加载 MNIST 数据集
对于 PyTorch 用户,可以利用 `torchvision.datasets.MNIST` 来完成相同的操作:
```python
import torch
from torchvision import datasets, transforms
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=True)
```
以上两种方式均会自动检测本地是否存在 MNIST 数据集;如果没有找到对应文件,则会在指定路径中下载它[^3]。
#### 手动下载 MNIST 数据集
当无法依赖第三方库时,也可以手动访问原始资源链接进行下载。官方提供了四个主要文件用于构建完整的 MNIST 集合:
- 训练标签 (`train-labels.idx1-ubyte`)
- 训练样本 (`train-images.idx3-ubyte`)
- 测试标签 (`t10k-labels.idx1-ubyte`)
- 测试样本 (`t10k-images.idx3-ubyte`)
可以从 Yann LeCun 的个人主页或者镜像站点获取这些压缩包,并解压至项目工作区内的特定子目录下(如上述提到的 `dataset/mnist/` 文件夹结构)[^3]。
之后可编写自定义解析脚本来读取 `.idx` 类型的数据流转换成 NumPy 数组形式以便后续处理分析。
---
### 示例代码展示
下面给出一段基于纯 Python 实现的手动加载 MNIST 数据的例子:
```python
import gzip
import numpy as np
def load_images(filename):
with gzip.open(filename, 'rb') as f:
_ = int.from_bytes(f.read(4), byteorder='big')
nimages = int.from_bytes(f.read(4), byteorder='big')
rows = int.from_bytes(f.read(4), byteorder='big')
cols = int.from_bytes(f.read(4), byteorder='big')
buf = f.read(nimages * rows * cols)
images = np.frombuffer(buf, dtype=np.uint8).astype(np.float32)
images = images.reshape(nimages, rows*cols)
return images / 255.
def load_labels(filename):
with gzip.open(filename, 'rb') as f:
_ = int.from_bytes(f.read(4), byteorder='big')
nlabels = int.from_bytes(f.read(4), byteorder='big')
labels = np.array(list(f.read(nlabels)), dtype=np.int64)
return labels
X_train = load_images('train-images.gz')
y_train = load_labels('train-labels.gz')
X_test = load_images('test-images.gz')
y_test = load_labels('test-labels.gz')
```
此段程序假设所有必要的 .gz 格式的 idx 文档已经被放置在同一级目录里[^3]。
---
阅读全文
相关推荐















