1.背景
在深度学习的开源框架中,Tensorflow是最热门的框架之一,相信很多同学已经有了不同程度的学习和了解。但站长在平时的沟通发现,很多同学反应不知道怎么使用自己训练好的模型进行预测,不知道怎么继续接着之前训练了多个轮次的模型进行训练,不知道怎么生成工业化场景里可上线的模型文件等等。 因此,站长会写一个针对Tensorflow的模型保存和加载的系列文章,为大家解决相关问题。
1.1 模型文件介绍
Tensorflow保存模型的时候会生成三个文件,分别是meta file,index file,data file。
meta file 这个文件是描述图结构,包括GraphDef, SaverDef等。值得注意的是,在Tensorfow中图和变量是分开的,关于图结构的信息主要保存在meta file中。
index file 这个文件是关于tensor的索引文件,key就是tensor的名字,value就是序列化后的BundleEntryProto。
data file 这个文件保存了所有变量的值。
1.2 模型的保存示例代码一
这里先简单地实现了一个例子,实现的是初始化了随机变量v1和v2,并将变量v1和v2相加获得变量v3。注意:这里保存模型的目录是result而不是model.ckpt。而这里的model.ckpt是模型文件的前缀,如果想要在同一个目录下保存多个模型的话,可以通过修改这个前缀达成目的。
# -*- coding: utf-8 -*-
import tensorflow as tf
v1 = tf.get_variable("v1", shape=[10], initializer=tf.random_normal_initializer)
v2 = tf.get_variable("v2", shape=[10], initializer=tf.random_normal_initializer)
v3 = tf.add(v1,v2, name="v3")
init_op = tf.global_variables_initializer()
saver = tf.train.Saver()
with tf.Session() as sess:
print("Start initialing model parameters")
sess.run(init_op)
print("v1 : %s" % v1.eval())
print("v2 : %s" % v2.eval())
print("v3 : %s" % v3.eval())
# Save the variables to disk.
save_path = saver.save(sess,