TF: Saving and Loading Models

2017.09.09 16:11

A quick complete tutorial to save and restore Tensorflow models


A Tensorflow model has two main files:

  1. meta graph
    This is a protocol buffer that saves the complete Tensorflow graph, i.e. all variables, operations, collections, etc. This file has the .meta extension.
  2. checkpoint file
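    This is a binary file that contains the values of the weights, biases, and all the other saved variables. This file has the extension .ckpt. However, Tensorflow changed this in version 0.11: instead of a single .ckpt file, there are now two files, a .index file and a .data file (for a model saved as mymodel, these are mymodel.index and mymodel.data-00000-of-00001).

The .data file contains the values of our training variables, and it is the file we are really after. The .index file records where each variable is stored inside the .data file.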

Along with these, Tensorflow also writes a file named checkpoint, which simply keeps a record of the latest checkpoint files saved.
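The checkpoint file is a small plain-text record. As far as I know, it is a text-format CheckpointState protocol buffer. Below is a minimal sketch of what it holds and how it could be read with the standard library. parse_checkpoint_state and the sample contents are hypothetical and for illustration only. The parser also handles only simple quoted string fields. In real code, tf.train.latest_checkpoint() does this lookup for you.

```python
import re

# Matches lines such as: model_checkpoint_path: "my_test_model-1000"
_LINE = re.compile(r'^\s*(\w+)\s*:\s*"(.*)"\s*$')


def parse_checkpoint_state(text):
    """Return (latest_prefix, all_prefixes) from a `checkpoint` file's text.

    Hypothetical helper: a simplified reader for the text-format record.
    Fields other than the two path fields are ignored.
    """
    latest = None
    all_prefixes = []
    for line in text.splitlines():
        match = _LINE.match(line)
        if not match:
            continue  # skip blank lines and unquoted fields (e.g. timestamps)
        key, value = match.groups()
        if key == "model_checkpoint_path":
            latest = value
        elif key == "all_model_checkpoint_paths":
            all_prefixes.append(value)
    return latest, all_prefixes


# Made-up example contents of a `checkpoint` file
sample = '''model_checkpoint_path: "my_test_model-1000"
all_model_checkpoint_paths: "my_test_model-500"
all_model_checkpoint_paths: "my_test_model-1000"
'''
latest, history = parse_checkpoint_state(sample)
print(latest)   # my_test_model-1000
print(history)  # ['my_test_model-500', 'my_test_model-1000']
```

The record stores checkpoint prefixes, not full file names. Appending .meta to the latest prefix gives the graph file to import.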





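import tensorflow as tf

# Two named variables; the names are used to find them again after restoring
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver()
sess = tf.Session()
# Variables must be initialized before their values can be saved
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my_test_model')

# This will save the following files in Tensorflow v >= 0.11
# my_test_model.data-00000-of-00001
# my_test_model.index
# my_test_model.meta
# checkpoint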

If we are saving the model after 1000 iterations, we shall call save by passing the step count:
saver.save(sess, 'my_test_model', global_step=1000)

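This will just append ‘-1000’ to the model name, and the following files will be created:

my_test_model-1000.index
my_test_model-1000.meta
my_test_model-1000.data-00000-of-00001
checkpoint
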
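The naming rule above can be written as a small function. This is a minimal sketch that assumes the V2 checkpoint format (Tensorflow >= 0.11) and the default single data shard. checkpoint_filenames is a hypothetical helper and is not part of Tensorflow.

```python
def checkpoint_filenames(prefix, global_step=None):
    """List the files a Saver checkpoint is expected to produce (sketch).

    Hypothetical helper; assumes one data shard (data-00000-of-00001).
    """
    if global_step is not None:
        prefix = "{}-{}".format(prefix, global_step)  # e.g. my_test_model-1000
    return [
        prefix + ".index",                 # variable name -> location table
        prefix + ".meta",                  # the serialized graph
        prefix + ".data-00000-of-00001",   # the variable values
        "checkpoint",                      # shared record of latest checkpoints
    ]


for name in checkpoint_filenames("my_test_model", global_step=1000):
    print(name)
```

Without global_step, the same call yields the my_test_model.* names from the first save example.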
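If you want to keep only the 4 most recent checkpoints, and also keep one checkpoint for every 2 hours of training so that it is not deleted when max_to_keep is exceeded, pass max_to_keep and keep_checkpoint_every_n_hours to the Saver. Note that you still call saver.save() yourself; these options only control which checkpoints are kept.

# Keep the 4 most recent checkpoints, plus one checkpoint per 2 hours of training
saver = tf.train.Saver(max_to_keep=4, keep_checkpoint_every_n_hours=2)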
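To see how the two options interact, here is a small simulation. It reflects my reading of the Saver's retention rule and is an assumption, not Tensorflow source. When more than max_to_keep checkpoints exist, the oldest one leaves the "recent" window. It is kept permanently only if it was saved after the current keep deadline; otherwise its files are deleted. simulate_retention is a hypothetical helper, and save times are given in hours since the Saver was created.

```python
from collections import deque


def simulate_retention(save_times_h, max_to_keep=4, keep_every_n_hours=2.0):
    """Simulate which checkpoints survive, given save times in hours.

    Simplified model (an assumption, not Tensorflow source): the Saver is
    created at hour 0; a checkpoint pushed out of the max_to_keep window is
    preserved only if it was saved after the current keep deadline.
    Returns (recent, preserved) lists of save times.
    """
    recent = deque()
    preserved = []
    next_keep = keep_every_n_hours  # first deadline: N hours after start
    for t in save_times_h:
        recent.append(t)
        while len(recent) > max_to_keep:
            oldest = recent.popleft()
            if oldest > next_keep:
                preserved.append(oldest)         # kept on disk for good
                next_keep += keep_every_n_hours  # move the deadline forward
            # otherwise the oldest checkpoint's files would be deleted


    return list(recent), preserved


# Save every 30 minutes for 6 hours with max_to_keep=4, keep every 2 hours
recent, preserved = simulate_retention([0.5 * k for k in range(1, 13)])
print(recent)     # [4.5, 5.0, 5.5, 6.0]
print(preserved)  # [2.5]
```

With these settings, the 2.5-hour checkpoint survives alongside the four most recent ones. Every other older checkpoint is removed. In this model the deadline advances from the previous deadline rather than from the preserved checkpoint's save time.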


If you want to use someone else’s pre-trained model for fine-tuning, there are two things you need to do:

  1. Create the network
    You can create the network by writing Python code that builds each and every layer manually, exactly as in the original model. However, we already saved the network in the .meta file, so we can recreate it with the tf.train.import_meta_graph() function like this:
    saver = tf.train.import_meta_graph('my_test_model-1000.meta')

Remember, import_meta_graph appends the network defined in the .meta file to the current graph. So this creates the graph/network for you, but we still need to load the values of the parameters that we trained on this graph.

  2. Load the parameters
    We can restore the parameters of the network by calling restore on this saver, which is an instance of the tf.train.Saver class.
import tensorflow as tf

with tf.Session() as sess:
  # Rebuild the graph from the .meta file, then load the latest saved values
  new_saver = tf.train.import_meta_graph('my_test_model-1000.meta')
  new_saver.restore(sess, tf.train.latest_checkpoint('./'))