我们需要在模型训练完成后将模型保存到文件系统上,以便于我们后续的测试与部署。在训练大规模的网络时,为了减少在训练过程中被中断/宕机意外的损失,间歇性保存模型状态是个好习惯。
1.张量方式
model.save_weights('weights.ckpt')
这种保存与加载网络的方式最为轻量级,文件中仅保存参数张量的数值,并没有额外的结构参数。但它需要使用相同的网络结构才能够恢复网络状态,因此一般在拥有网络源文件的情况下使用
2.网络方式
通过Model.save(path)函数可以将模型的结构以及模型的参数保存到一个path文件上,在不需要网络源文件的条件下,通过keras.models.load_model(path)即可恢复网络结构和网络参数
model.save('model.h5')
此时通过model.h5文件即可恢复出网络的结构和状态:
net = tf.keras.models.load_model('model.h5')
net.summary()
model.h5同时保存了模型参数和网络结构信息,不需要提前创建模型即可直接从文件恢复出网络net对象
在学习别人的代码时,看到了如下代码
model_name = 'xxx'
backbone_name = 'xxx'
CKPT_PATH = "xxx"
t= time.strftime("%Y%m%d-%H%M%S", time.localtime())
model.save_weights(CKPT_PATH+'{}_{}_{}_model.h5'.format(t,model_name,backbone_name))
感觉这样很方便参加竞赛时看各个版本的模型
参考资料:天池学习