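1. Introduction to the MNIST dataset
First, download the four MNIST files from the official website. The training set contains 60,000 digit images, and the test set contains 10,000 images used for evaluation. Each file, and the way its contents are stored, is listed below.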
TRAINING SET LABEL FILE (train-labels-idx1-ubyte):
[offset] [type] [value] [description]
0000 32 bit integer 0x00000801(2049) magic number (MSB first)
0004 32 bit integer 60000 number of items
0008 unsigned byte ?? label
0009 unsigned byte ?? label
........
xxxx unsigned byte ?? label
The label values are 0 to 9.
TRAINING SET IMAGE FILE (train-images-idx3-ubyte):
[offset] [type] [value] [description]
0000 32 bit integer 0x00000803(2051) magic number
0004 32 bit integer 60000 number of images
0008 32 bit integer 28 number of rows
0012 32 bit integer 28 number of columns
0016 unsigned byte ?? pixel
0017 unsigned byte ?? pixel
........
xxxx unsigned byte ?? pixel
Pixels are organized row-wise. Pixel values are 0 to 255. 0 means background (white), 255 means foreground (black).
TEST SET LABEL FILE (t10k-labels-idx1-ubyte):
[offset] [type] [value] [description]
0000 32 bit integer 0x00000801(2049) magic number (MSB first)
0004 32 bit integer 10000 number of items
0008 unsigned byte ?? label
0009 unsigned byte ?? label
........
xxxx unsigned byte ?? label
The label values are 0 to 9.
TEST SET IMAGE FILE (t10k-images-idx3-ubyte):
[offset] [type] [value] [description]
0000 32 bit integer 0x00000803(2051) magic number
0004 32 bit integer 10000 number of images
0008 32 bit integer 28 number of rows
0012 32 bit integer 28 number of columns
0016 unsigned byte ?? pixel
0017 unsigned byte ?? pixel
........
xxxx unsigned byte ?? pixel
Pixels are organized row-wise. Pixel values are 0 to 255. 0 means background (white), 255 means foreground (black).
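The spec says pixels are organized row-wise. As a small illustration (a synthetic array, not real MNIST data), the pixel at row r and column c of a 28x28 image sits at flat offset r*28 + c, which is exactly the order NumPy's default (row-major) reshape uses:

```python
import numpy as np

ROWS, COLS = 28, 28

# Synthetic flat "image" whose value at each position equals its offset.
# uint16 is used because offsets go up to 783, beyond the uint8 range of real pixels.
flat = np.arange(ROWS * COLS, dtype=np.uint16)

# Row-wise (C-order) reshape: each consecutive run of 28 values becomes one row
img = flat.reshape(ROWS, COLS)

r, c = 3, 5
print(img[r, c], r * COLS + c)  # 89 89
```

This is why reshape(28, 28) on a 784-value row restores the original picture without any transposing.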
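As you can see, the first few bytes of each of the four files form a header that describes the file's contents rather than holding data: a magic number, the number of items (60,000 training images, 10,000 test images, and the same counts in their label files) and, for the image files, the number of rows and columns of each image. The header is 8 bytes long in the label files and 16 bytes long in the image files; the actual data starts right after it.
2. Reading the MNIST dataset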
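To make the header layout concrete, here is a minimal sketch that builds tiny made-up label and image files in memory (not the real MNIST files) following the tables above, and then parses their headers with struct. Because np.fromfile needs a real file on disk, this sketch uses np.frombuffer on the remaining in-memory bytes instead.

```python
import io
import struct
import numpy as np

# Synthetic label file: 8-byte header (big-endian magic 2049 + item count), then one byte per label
label_bytes = struct.pack('>II', 2049, 3) + bytes([7, 2, 1])
lbuf = io.BytesIO(label_bytes)
magic, n = struct.unpack('>II', lbuf.read(8))
print(lbuf.tell())  # 8: the labels start right after the header
labels = np.frombuffer(lbuf.read(), dtype=np.uint8)
print(magic, n, labels.tolist())  # 2049 3 [7, 2, 1]

# Synthetic image file: 16-byte header (magic 2051, count, rows, cols), then the pixels
image_bytes = struct.pack('>IIII', 2051, 1, 2, 2) + bytes([0, 255, 128, 0])
ibuf = io.BytesIO(image_bytes)
img_magic, num, rows, cols = struct.unpack('>IIII', ibuf.read(16))
print(ibuf.tell())  # 16: the pixels start right after the header
pixels = np.frombuffer(ibuf.read(), dtype=np.uint8).reshape(num, rows * cols)
print(img_magic, num, rows, cols, pixels.shape)  # 2051 1 2 2 (1, 4)
```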
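import os
import struct
import numpy as np
import matplotlib.pyplot as plt

def load_mnist(path, kind='train'):  # the kind parameter lets the same function load the test set ('t10k') later
    """Load MNIST images and labels of the given kind from path."""
    # os.path.join joins all the path components passed to it.
    # The patterns follow the file names listed in section 1 (e.g. train-labels-idx1-ubyte);
    # if your extracted files use a dot instead (train-labels.idx1-ubyte), adjust them.
    labels_path = os.path.join(path, '%s-labels-idx1-ubyte' % kind)
    images_path = os.path.join(path, '%s-images-idx3-ubyte' % kind)

    with open(labels_path, 'rb') as lbpath:
        # 'I' is an unsigned 4-byte integer; '>' means big-endian (MSB first)
        # '>II' reads two such integers, i.e. 8 bytes: the magic number and the item count
        # After read(8) the file pointer sits at offset 8 (the 9th byte), where the labels begin
        magic, n = struct.unpack('>II', lbpath.read(8))
        # np.fromfile reads everything from the current position to the end of the file
        labels = np.fromfile(lbpath, dtype=np.uint8)
        print(magic, n)  # shows which header fields these values correspond to

    with open(images_path, 'rb') as imgpath:
        # The image header is 16 bytes: magic number, image count, rows, columns
        magic, num, rows, cols = struct.unpack('>IIII', imgpath.read(16))
        # Each row of the result is one flattened 28x28 = 784-pixel image
        images = np.fromfile(imgpath, dtype=np.uint8).reshape(len(labels), rows * cols)
        print(magic, num, rows, cols)  # shows which header fields these values correspond to

    return images, labels

# A raw string (r'...') stops the Windows backslashes from being read as escape sequences
data_dir = r'D:\Program Files\JetBrains\My Project\mnist_data'
X_train, y_train = load_mnist(data_dir, kind='train')
X_test, y_test = load_mnist(data_dir, kind='t10k')

fig, ax = plt.subplots(nrows=2, ncols=5, sharex=True, sharey=True)
ax = ax.flatten()  # flatten the 2x5 array of Axes into a 1-D array so one loop index can address it
for i in range(10):
    # Take the first training image whose label is i and restore its 28x28 shape
    img = X_train[y_train == i][0].reshape(28, 28)
    # cmap selects the colour map ('Greys' draws low values white and high values black);
    # interpolation selects how pixels are interpolated when drawn
    ax[i].imshow(img, cmap='Greys', interpolation='nearest')
# The axes are shared, so removing the ticks on ax[0] removes them on every subplot
ax[0].set_xticks([])
ax[0].set_yticks([])
plt.tight_layout()  # automatically adjust subplot parameters so the subplots fill the figure area
plt.show()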
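To check the reading logic without downloading MNIST, one can write a tiny fake dataset in the same format to a temporary directory and read it back. This is a sketch under assumptions: write_fake_idx and read_idx are hypothetical helpers, the file names are the hyphenated ones listed in section 1, and read_idx repeats the reading steps of load_mnist (without the prints) while also checking the magic numbers.

```python
import os
import struct
import tempfile
import numpy as np

def write_fake_idx(path, kind, images, labels):
    """Hypothetical helper: write images/labels in the IDX layout described in section 1."""
    num, rows, cols = images.shape
    with open(os.path.join(path, '%s-labels-idx1-ubyte' % kind), 'wb') as f:
        f.write(struct.pack('>II', 2049, num))
        f.write(labels.astype(np.uint8).tobytes())
    with open(os.path.join(path, '%s-images-idx3-ubyte' % kind), 'wb') as f:
        f.write(struct.pack('>IIII', 2051, num, rows, cols))
        f.write(images.astype(np.uint8).tobytes())  # C order, i.e. row-wise pixels

def read_idx(path, kind):
    """Hypothetical helper: same reading steps as load_mnist, plus magic-number checks."""
    with open(os.path.join(path, '%s-labels-idx1-ubyte' % kind), 'rb') as f:
        magic, n = struct.unpack('>II', f.read(8))
        if magic != 2049:
            raise ValueError('not an IDX label file')
        labels = np.fromfile(f, dtype=np.uint8)
    with open(os.path.join(path, '%s-images-idx3-ubyte' % kind), 'rb') as f:
        magic, num, rows, cols = struct.unpack('>IIII', f.read(16))
        if magic != 2051:
            raise ValueError('not an IDX image file')
        images = np.fromfile(f, dtype=np.uint8).reshape(num, rows * cols)
    return images, labels

rng = np.random.default_rng(0)
fake_images = rng.integers(0, 256, size=(5, 28, 28)).astype(np.uint8)
fake_labels = np.array([3, 1, 4, 1, 5], dtype=np.uint8)

with tempfile.TemporaryDirectory() as tmp:
    write_fake_idx(tmp, 'train', fake_images, fake_labels)
    X, y = read_idx(tmp, 'train')

print(X.shape, y.tolist())  # (5, 784) [3, 1, 4, 1, 5]
```

Raising ValueError instead of using assert keeps the checks active even when Python runs with -O.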
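The printed output is shown below. There are two groups of lines because load_mnist is called twice, once with kind='train' and once with kind='t10k'; in each group the first line comes from the label file and the second from the image file: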
2049 60000
2051 60000 28 28
2049 10000
2051 10000 28 28
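The magic numbers 2049 and 2051 in this output are not arbitrary. According to the IDX format description on the MNIST page, the first two bytes of the magic number are always 0, the third byte encodes the data type (0x08 means unsigned byte) and the fourth byte gives the number of dimensions. A minimal sketch that decodes the two printed values (decode_magic is a hypothetical helper name):

```python
import struct

def decode_magic(magic):
    """Split an IDX magic number into (data-type code, number of dimensions)."""
    zero1, zero2, type_code, ndim = struct.pack('>I', magic)  # the 4 big-endian bytes
    if zero1 != 0 or zero2 != 0:
        raise ValueError('the first two bytes of an IDX magic number must be 0')
    return type_code, ndim

for magic in (2049, 2051):
    type_code, ndim = decode_magic(magic)
    print(magic, hex(type_code), ndim)
# 2049 0x8 1  -> unsigned bytes, 1 dimension (the label vector)
# 2051 0x8 3  -> unsigned bytes, 3 dimensions (images x rows x columns)
```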

My own questions and guesses about the program:
img = X_train[y_train == i][0].reshape(28, 28)
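Here X_train holds the images (the images returned by the function), y_train holds the labels (the labels returned by the function), and reshape turns a flattened 784-pixel row back into a 28x28 image.
What puzzles me is the [y_train == i] part of X_train[y_train == i][0]. Here i is the loop variable, running from 0 to 9, and my guess is that the expression looks in the first column for an image whose label is i (for example 0 on the first iteration). That is only my interpretation; I could not find a detailed explanation online, so I am not sure it is correct. If anyone can explain it in detail and tell me whether I have misunderstood it, I would be grateful.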
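Here is a small sketch of what X_train[y_train == i][0] does, using made-up toy arrays in place of the real data. y_train == i compares every label with i and builds a boolean array with one True/False per sample (it looks at the labels, not at a column of X_train); indexing X_train with that array keeps every row, i.e. every image, whose label equals i, in the original order; and [0] then takes the first of those rows. So on the first iteration (i = 0) the expression returns the first training image labelled 0.

```python
import numpy as np

# Toy stand-ins: 6 "images" of 4 pixels each, and their labels
X = np.arange(24).reshape(6, 4)   # row k holds the pixels of sample k
y = np.array([5, 0, 4, 0, 9, 2])  # label of each sample

mask = (y == 0)
print(mask)         # [False  True False  True False False]

selected = X[mask]  # all rows whose label is 0 (rows 1 and 3), shape (2, 4)
print(selected)

first = X[mask][0]  # the first of them: row 1
print(first)        # [4 5 6 7]

# Same result without first copying every matching row:
idx = np.flatnonzero(y == 0)[0]
print(idx, X[idx])  # 1 [4 5 6 7]
```

Note that if no sample carries label i, X[mask] is empty and [0] raises an IndexError; with MNIST every digit 0 to 9 appears, so the plotting loop is safe.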
This article draws on the following related blog posts:
https://blog.csdn.net/qq_33254870/article/details/81388620
https://blog.csdn.net/simple_the_best/article/details/75267863