Tensorflow模型保存和加载

本文翻译自http://cv-tricks.com,介绍了TensorFlow模型的保存和加载过程。包括Meta graph、Checkpoint文件的解释,以及如何使用tf.train.Saver进行模型的保存和恢复。详细阐述了在不同阶段如何操作,例如手动创建网络、注入训练数据、恢复权重,并提供了相关操作如tf.multiply和tf.stop_gradient的说明。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

翻译自:http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/

一、什么是TensorFlow 模型:

1. Meta graph:协议的缓冲区,保存了完整的TensorFlow图,包括变量、操作、集合等。

2. Checkpoint file:二进制文件,包含所有权重、偏值、梯度、其他变量的值。它有一个扩展文件.ckpt。从0.11版本后,又扩展了两个文件:.data-00000-of-00001.index.data-00000-of-00001是一个包含训练过的变量的文件。.index应该就是.data-00000-of-00001对应的变量的下标。checkpoint记录最新保存的checkpoint file

0.11版本后:包括checkpoint.data-00000-of-00001.index.meta

0.11版本前:包括checkpoint.ckpt.meta

二、如何保存TensorFlow 模型:

训练完一个网络后,我们需要保存所有参数值和网络图(图就是.meta文件),以备之后使用。

1. tf.train.Saver()类的实例 saver = tf.train.Saver()

2. TensorFlow的变量只在session中存活,所以必须通过调用对象saversave(session,modelName)方法将session(包括checkpoint.data-00000-of-00001.index.meta)保存在modelName中。

Q:这些文件什么时候生成的,在执行save之前已经存在了?到底是session保存到modelName还是modelName保存到session
A:先将所有的.meta.data.index文件保存到session中,调用save()是将这个session保存到某个文件中,名字为modelName


3. 如果想要保存指定的变量或集合,可以在实例化Saver()时传一个包括这些变量的listdictionary,如tf.train.Saver([变量名1,变量名2])

三、如何导入预训练过的model(想要在已有的训练过的modelfine-tuning(微调))

1. 创建网络:

Python代码手动创建每一层作为原始模型,保存到.meta文件中,如saver =tf.train.import_meta_graph(.meta文件名) ##将手动创建的网络添加到.meta文件中

2. 加载网络训练的参数值

通过saver.restore()方法恢复网络的参数值,

saver.restore(session,tf.train.latest_checkpoint(‘./’)),之后张量w1w2就被恢复,可以被访问。

Q:什么叫restore

Arestore相当于加载,先重新生成一个session,将保存到modelName中的文件重新加载到新的session中。

 

四、使用restoremodel实践

1. 使用占位符注入所有的训练数据和超参数:

2. 定义一个将要恢复的测试操作:

3. 实例化一个saver对象用来保存所有变量:

4. 使用注入的值运行第2步定义的操作:

5. 保存网络图到session

 

1. 得到保存在图中的变量或张量或占位符:

2. 得到保存在图中的操作:

 如果只想要改变数据,不改变网络

1. 加载.meta文件,恢复权重:

2. 获得占位符变量,注入新的数据:

3. 获得操作:

4. 运行session

 

如果想要添加更多的操作或更多的层,然后训练,需要:

1. 加载.meta文件,恢复权重:

2. 获得占位符变量,注入新的数据:

3. 获得操作:

4. 在当前网络中添加操作:

5. 运行session

 

TensorFlow中的方法:

tf.multiply(x1,x2)  ##x1,x2点乘

tf.matmul(x1,x2)  ##矩阵相乘

tf.gradient(ys,xs)   ##计算ys关于xsx的偏导,返回xs中每个x对应的sum(dy/dx)

tf.stop_gradient(input)  ##阻挡节点BP的梯度,节点被阻挡后,这个节点上的梯度就无法再向前BP

tf.truncated_normal(shape,mean,stddev)  ##shape表示张量的维度,mean是均值,stddev是标准差,这个函数产生正态分布,产生的值如果与均值的差大于2倍的标准差,就重新生成。

tf.constant(value,dtype=None,shape=None,name=’Const’)  ##创建一个常量tensor,按照value赋值,shape指定形状。

总结:第一篇博客是翻译的英文网页的内容,正好这段时间在学习TensorFlow相关的东西,感谢CV-Tricks.com提供的详细讲解。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值