tf.train.Saver() 与tf.train.import_meta_graph区别

4100阅读 0评论2020-04-13 wwm
分类:Python/Ruby

一、tf.train.Saver()
(1). tf.train.Saver() 是用来保存tensorflow训练模型的,默认保存全部参数
(2). 用来加载参数,   
        
:只加载存储在data中的权重和偏置项等需要训练的参数,其他一律不加载,
包括meta文件中的图

模型文件:

.ckpt文件:是旧版本的输出saver.save(sess),相当于你的.ckpt-data
“checkpoint”:文件仅用于告知某些TF函数,这是最新的检查点文件。
.ckpt-meta:包含元图,即计算图的结构,没有变量的值(基本上你可以在tensorboard / graph中看到)。
.ckpt-data:包含所有变量的值,没有结构。
.ckpt-index:可能是内部需要的某种索引来正确映射前两个文件,它通常不是必需的

二、tf.train.import_meta_graph(".meta文件")
加载计算图。一般用不到,有时候,想重复使用前面的计算或者代码,可调用import_meta_graph来复用。经常使用恢复参数即可。


注意无论是参数变量还是计算图 都调用save保存。

参考下面列子

点击(此处)折叠或打开

  1. # 连同图结构一同加载
  2. ckpt = tf.train.get_checkpoint_state('./model/')
  3. saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path +'.meta')
  4. with tf.Session() as sess:
  5.  saver.restore(sess,ckpt.model_checkpoint_path)
  6.      
  7. # 只加载数据,不加载图结构,可以在新图中改变batch_size等的值
  8. # 不过需要注意,Saver对象实例化之前需要定义好新的图结构,否则会报错
  9. saver = tf.train.Saver()
  10. with tf.Session() as sess:
  11.  ckpt = tf.train.get_checkpoint_state('./model/')
  12.  saver.restore(sess,ckpt.model_checkpoint_path)



上一篇:Tensorflow_实现断点续训功能
下一篇:tensorflow张量维度操作