翻译自:http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/
一、什么是TensorFlow 模型:
1. Meta graph:协议的缓冲区,保存了完整的TensorFlow图,包括变量、操作、集合等。
2. Checkpoint file:二进制文件,包含所有权重、偏值、梯度、其他变量的值。它有一个扩展文件.ckpt。从0.11版本后,又扩展了两个文件:.data-00000-of-00001和.index。.data-00000-of-00001是一个包含训练过的变量的文件。.index应该就是.data-00000-of-00001对应的变量的下标。checkpoint记录最新保存的checkpoint file。
0.11版本后:包括checkpoint、.data-00000-of-00001、.index、.meta
0.11版本前:包括checkpoint、.ckpt、.meta
二、如何保存TensorFlow 模型:
创训练完一个网络后,我们需要保存所有参数值和网络图(图就是.meta文件),以备之后使用。
1. 建tf.train.Saver()类的实例 saver = tf.train.Saver()
2. TensorFlow的变量只在session中存活,所以必须通过调用对象saver的save(session,modelName)方法将session(包括checkpoint、.data-00000-of-00001、.index、.meta)保存在modelName中。
Q:这些文件什么时候生成的,在执行save之前已经存在了?到底是session保存到modelName还是modelName保存到session?
A:先将所有的.meta和.data,.index文件保存到session中,调用save()是将这个session保存到某个文件中,名字为modelName。
3. 如果想要保存指定的变量或集合,可以在实例化Saver()时传一个包括这些变量的list或dictionary,如tf.train.Saver([变量名1,变量名2])。
三、如何导入预训练过的model(想要在已有的训练过的model上fine-tuning(微调))
1. 创建网络:
用Python代码手动创建每一层作为原始模型,保存到.meta文件中,如saver =tf.train.import_meta_graph(.meta文件名) ##将手动创建的网络添加到.meta文件中
2. 加载网络训练的参数值
通过saver.restore()方法恢复网络的参数值,
如saver.restore(session,tf.train.latest_checkpoint(‘./’)),之后张量w1和w2就被恢复,可以被访问。
Q:什么叫restore?
A:restore相当于加载,先重新生成一个session,将保存到modelName中的文件重新加载到新的session中。
四、使用restore的model实践
1. 使用占位符注入所有的训练数据和超参数:
2. 定义一个将要恢复的测试操作:
3. 实例化一个saver对象用来保存所有变量:
4. 使用注入的值运行第2步定义的操作:
5. 保存网络图到session
1. 得到保存在图中的变量或张量或占位符:
2. 得到保存在图中的操作:
如果只想要改变数据,不改变网络
1. 加载.meta文件,恢复权重:
2. 获得占位符变量,注入新的数据:
3. 获得操作:
4. 运行session:
如果想要添加更多的操作或更多的层,然后训练,需要:
1. 加载.meta文件,恢复权重:
2. 获得占位符变量,注入新的数据:
3. 获得操作:
4. 在当前网络中添加操作:
5. 运行session:
TensorFlow中的方法:
tf.multiply(x1,x2) ##x1,x2点乘
tf.matmul(x1,x2) ##矩阵相乘
tf.gradient(ys,xs) ##计算ys关于xs中x的偏导,返回xs中每个x对应的sum(dy/dx)
tf.stop_gradient(input) ##阻挡节点BP的梯度,节点被阻挡后,这个节点上的梯度就无法再向前BP了
tf.truncated_normal(shape,mean,stddev) ##shape表示张量的维度,mean是均值,stddev是标准差,这个函数产生正态分布,产生的值如果与均值的差大于2倍的标准差,就重新生成。
tf.constant(value,dtype=None,shape=None,name=’Const’) ##创建一个常量tensor,按照value赋值,shape指定形状。
总结:第一篇博客是翻译的英文网页的内容,正好这段时间在学习TensorFlow相关的东西,感谢CV-Tricks.com提供的详细讲解。