Tensorflow 模型的保存和加载

今天写测试程序的时候发现预测结果错到离谱,眼看又要哭晕在厕所的我,又仔细检查了一遍训练程序,发现是模型保存错了 -_-||| ,把saver放在了循环的外面,这就很尴尬了。。。改完又可以给自己放个小长假,让程序自己慢慢重跑一次吧啦啦啦小魔仙全身变。。。

一、模型的保存

分两步。

1.在计算图之后(所有变量节点都创建好之后),定义一个 saver 对象。
2.开启 Session ,利用 saver 保存模型。

首先,在定义计算图之后,开启会话之前,定义一个 saver 对象。

saver = tf.train.Saver()

Saver 类在初始化时,有一些常用的参数:

  • var_list 默认为 None,即保存所有可保存的对象。
  • reshape为 True 时,表示从一个 checkpoint 中恢复参数时允许参数shape发生变化。(当我们reshape了一个变量又希望加载旧模型时,该操作就很有用。)
  • max_to_keep 自动保存 max_to_keep 个模型,默认值为 5。(也就是说,尽管程序每个 step 保存一次模型,但实际上只会保存最近的5次。)
  • keep_checkpoint_every_n_hours 用于指定保留 Checkpoints 文件的时间,默认为 10000 小时。

然后,在开启 Session 会话后,利用 saver 保存模型:

# 开启会话
with tf.Session() as sess:
  sess.run(init)
  ***省略代码***

  #保存模型
  # 注意:路径最后一项是模型名字,加载时模型路径应该为‘save/model/’
  saver.save(sess,'save/model/model',global_step=step)
  • 第一个参数 sess 是定义的会话,记录了这次训练中所有变量的值。
  • 第二个参数是模型保存的路径和名字。
  • 第三个参数用于把训练时的迭代次数加入文件名。

例如:

# 模型的文件名:my_model-1
saver.save(sess,'save/model',global_step=1)
# 模型的文件名:my_model-1000
saver.save(sess,'save/model',global_step=1000)

保存之前要记住,saver自动保存max_to_keep个模型(默认为5个),多了也没用,会自动忽略哒~

下面是几种常用的使用情况:

使用1 每次迭代保存一个模型
for i in range(2000):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
  saver.save(sess, './model/model', global_step=i+1)
使用2 每100次迭代保存一个模型
# 一共迭代num_step次
for i in range(num_step):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
  if i%100 == 0:
    saver.save(sess, './model/model', global_step=i+1)
使用3 保存结果最好的模型
# 一共迭代num_step次
 max_acc = 0
for i in range(num_step):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  val_loss,val_acc=sess.run([loss,acc], feed_dict={x: batch_xs, y_: batch_ys})
  sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
  if val_acc>max_acc:
    max_acc = val_acc
    saver.save(sess, './model/model', global_step=i+1)
使用4 保存结果最好的3个模型
saver = tf.train.Saver(max_to_keep=3)
***省略代码***

# 一共迭代num_step次
 max_acc = 0
for i in range(num_step):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  val_loss,val_acc=sess.run([loss,acc], feed_dict={x: batch_xs, y_: batch_ys})
  sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
  if val_acc>max_acc:
    max_acc = val_acc
    saver.save(sess, './model/model', global_step=i+1)

模型路径下会出现4个文件:
checkpoint 保存目录下所有模型的文件列表
.index / .data 保存模型所有参数
.meta 保存计算图

'./model ' 路径下所保存的模型
'./model ' 路径下所保存的模型

二、模型的加载

模型恢复用的是restore(sess, save_path)函数,它需要两个参数,sess表示当前会话,之前保存的结果将被加载入这个会话,save_path指的是保存的模型路径。如:

# 加载模型参数
saver.restore(sess, "model/model-xxxx")  # xxxx是指定的加载模型,注意这里不用加模型的后缀名

注意:这里只加载了模型的所有参数,需要重新定义计算图。如果不想重新定义计算图,也可以直接加载持久化的计算图:

# 加载计算图
saver =tf.train.import_meta_graph("Model/model.ckpt.meta") 

若不指定加载模型,可以直接获得训练过程中最后保存的模型,以下两种方法可以实现获得最近一次保存的模型:
获得最近一次保存的模型 方法一
我们可以使用tf.train.latest_checkpoint()函数来自动获取最后一次保存的模型。如:

model = tf.train.latest_checkpoint('model/')  # 保存模型所在的路径
print(model)  
# ./model\model.ckpt-47557
saver.restore(sess,model)

获得最近一次保存的模型 方法二
我们可以使用tf.train.get_checkpoint_state()函数来自动获取最后一次保存的模型。如:

ckpt = tf.train.get_checkpoint_state('./model')
print(ckpt)
# model_checkpoint_path: "./model\\model.ckpt-47557"
# all_model_checkpoint_paths: "./model\\model.ckpt-40992"
# all_model_checkpoint_paths: "./model\\model.ckpt-45218"
# all_model_checkpoint_paths: "./model\\model.ckpt-47557"
print(ckpt.model_checkpoint_path)
# './model\\model.ckpt-47557'
saver.restore(sess, ckpt.model_checkpoint_path)

Reference
Tensorflow模型的保存与恢复
tensorflow模型保存与加载
TensorFlow模型保存和提取方法