一、直接上完整代码（已注释）
import glob
import os

import numpy as np
import tensorflow as tf
from PIL import Image  # noqa: F401  -- Pillow is required by keras load_img; PIL has no lowercase `image` module
from tensorflow.keras import Model
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Dropout, Flatten, Dense
from tensorflow.keras.preprocessing import image
# Checkpoint path for the Baseline model weights; read by model.load_weights() below.
# NOTE(review): the ".ckpt" suffix suggests the TensorFlow checkpoint format — confirm
# it matches how the weights were saved during training.
model_save_path = './checkpoint/Baseline.ckpt'
class Baseline(Model):
    """Small CNN classifier: one convolutional stage plus a two-layer dense head.

    Layers, in the order they are applied:
    Conv2D(6 filters, 5x5, same) -> BatchNorm -> ReLU -> MaxPool(2x2, stride 2)
    -> Dropout(0.2) -> Flatten -> Dense(128, relu) -> Dropout(0.2)
    -> Dense(10, softmax).
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): these attribute names are most likely the keys under which
        # load_weights() finds the variables in a TF-format checkpoint; renaming
        # them would probably break restoring existing weights — keep them stable.

        # Convolutional feature extractor.
        self.c1 = Conv2D(filters=6, kernel_size=(5, 5), padding='same')
        self.b1 = BatchNormalization()
        self.a1 = Activation('relu')
        self.p1 = MaxPool2D(pool_size=(2, 2), strides=2, padding='same')
        self.d1 = Dropout(0.2)

        # Fully connected classifier head.
        self.flatten = Flatten()
        self.f1 = Dense(128, activation='relu')
        self.d2 = Dropout(0.2)
        self.f2 = Dense(10, activation='softmax')

    def call(self, x):
        """Feed ``x`` through every layer in order and return the 10-way softmax output."""
        pipeline = (
            self.c1, self.b1, self.a1, self.p1, self.d1,
            self.flatten, self.f1, self.d2, self.f2,
        )
        out = x
        for layer in pipeline:
            out = layer(out)
        return out
# Restore the trained Baseline weights and classify one 32x32 image.
model = Baseline()
model.load_weights(model_save_path)

# Load the image resized to 32x32, convert it to an array, and add a leading
# batch axis of size 1 so it can be passed to predict().
# NOTE(review): pixels are passed as raw 0-255 values; if Baseline was trained on
# inputs scaled to [0, 1], divide by 255 here — confirm against the training script.
test_image = np.expand_dims(
    image.img_to_array(image.load_img('马.jpg', target_size=(32, 32))),
    axis=0,
)

prediction = model.predict(test_image)
# Index of the highest-scoring class for each row of the batch.
pred = tf.argmax(prediction, axis=1)
tf.print("预测类别为", pred)
二、知乎大神的方法：一次输入 n 张图片，进行批量预测
# Batch prediction: stack all N images into one (N, 224, 224, 3) array and run
# ResNet50 (ImageNet weights) on the whole batch with a single predict() call.
file_path = 'D:/Data/dogs/'
# Sort so that row k of the predictions always corresponds to f_names[k];
# glob's own order is arbitrary.
f_names = sorted(glob.glob(os.path.join(file_path, '*.jpg')))
if not f_names:
    # np.concatenate would otherwise fail with an opaque "need at least one array" error.
    raise FileNotFoundError("no .jpg files found in %r" % file_path)

imgs = []
for i, f_name in enumerate(f_names):
    img = image.load_img(f_name, target_size=(224, 224))
    arr_img = image.img_to_array(img)
    arr_img = np.expand_dims(arr_img, axis=0)  # (224, 224, 3) -> (1, 224, 224, 3)
    imgs.append(arr_img)
    print("loading no.%s image." % i)

# Join the N arrays of shape (1, 224, 224, 3) along axis 0 -> (N, 224, 224, 3).
x = np.concatenate(imgs)
# ResNet50's ImageNet weights expect inputs transformed by the matching
# preprocess_input; feeding raw 0-255 pixels gives wrong predictions.
x = preprocess_input(x)
print("predicting...")
model = ResNet50(weights='imagenet')
y = model.predict(x)
print("Completed!")
这里使用 `np.concatenate` 将每张形状为 1×224×224×3 的图片数组沿第 0 维拼接，得到形状为 N×224×224×3 的批量数组（N 为图片总数），从而一次性完成预测。
在此附上知乎原文链接
三、附上将 CIFAR-10（Python 版）数据集保存为图片的代码
供自己复习。

import numpy as np
import pickle
import imageio
def unpickle(file):
    """Load and return the object pickled in ``file``.

    Args:
        file: Path to a pickled batch file (e.g. ``"data_batch_1"``).

    Returns:
        The deserialized object. For CIFAR-style batches this is the dict the
        callers index with ``['data']`` and ``['labels']``.
    """
    # SECURITY: pickle.load can execute arbitrary code embedded in the file;
    # only call this on trusted dataset files.
    # The `with` block closes the file even if unpickling raises.
    with open(file, 'rb') as fo:
        # encoding='latin1' lets Python 3 read pickles written by Python 2,
        # including embedded NumPy arrays.
        return pickle.load(fo, encoding='latin1')
import os

# Convert the five training batches and the test batch into individual image
# files named "<label>_<global index>.jpg" under train/ and test/.
# NOTE: JPEG is lossy, so the saved pixels will not exactly match the arrays;
# switch the suffix to ".png" if exact pixel values matter.


def _save_batch(batch, out_dir, index_offset=0):
    """Write every image of ``batch`` into ``out_dir`` and return how many were written.

    Each row of ``batch['data']`` is reshaped to (3, 32, 32) and transposed to
    (32, 32, 3) before being saved as ``<label>_<index_offset + i>.jpg``.
    """
    # Writing into a directory that does not exist fails, so create it first.
    os.makedirs(out_dir, exist_ok=True)
    count = 0
    for i, (row, label) in enumerate(zip(batch['data'], batch['labels'])):
        img = np.reshape(row, (3, 32, 32)).transpose(1, 2, 0)
        pic_name = os.path.join(out_dir, "{}_{}.jpg".format(label, i + index_offset))
        imageio.imwrite(pic_name, img)
        count += 1
    return count


# Running offset so that training file names are unique across all five batches.
offset = 0
for j in range(1, 6):
    dataName = "data_batch_" + str(j)
    print(dataName + " is loading...")
    Xtr = unpickle(dataName)
    offset += _save_batch(Xtr, 'train', offset)
    print(dataName + " loaded.")

print("test_batch is loading...")
testXtr = unpickle("test_batch")
_save_batch(testXtr, 'test')
print("test_batch loaded.")