# 扣丁学堂浅谈将TensorFlow的模型网络导出为单个文件的方法

import tensorflow as tf

from tensorflow.python.framework.graph_util import convert_variables_to_constants

# Build the network.
a = tf.Variable([[3], [4]], dtype=tf.float32, name='a')
b = tf.Variable(4, dtype=tf.float32, name='b')

# The output tensor MUST be given an explicit name: that name is how the node
# is selected when freezing the graph and how it is looked up after import.
# NOTE(review): the op itself was missing from this listing; `a + b` is
# reconstructed from the printed result [[7.], [8.]] -- confirm.
output = tf.add(a, b, name='out')

# Convert the Variables to constants and write the network to a single file.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # The names of the output nodes go here (node name, without the ':0').
    graph = convert_variables_to_constants(sess, sess.graph_def, ["out"])
    tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)

import tensorflow as tf

with tf.Session() as sess:
    # Load the serialized GraphDef from disk; it is opened in binary mode
    # because the bytes are handed straight to the protobuf parser.
    with open('./graph.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        # Without this parse step graph_def stays empty and 'out:0' cannot
        # be resolved by import_graph_def.
        graph_def.ParseFromString(f.read())

    # Import the graph into the current default graph and fetch the tensor
    # named 'out:0' from it.
    output = tf.import_graph_def(graph_def, return_elements=['out:0'])
    print(sess.run(output))

# Output: [array([[ 7.],       [ 8.]], dtype=float32)]

import tensorflow as tf

from tensorflow.python.framework.graph_util import convert_variables_to_constants

# NOTE(review): presumably this snippet is meant to run as its own script; if
# it shares a process with an earlier snippet, the default graph already holds
# nodes named 'a', 'b' and 'out' and the new ones get uniquified names.
a = tf.Variable([[3], [4]], dtype=tf.float32, name='a')
b = tf.Variable(4, dtype=tf.float32, name='b')

# Placeholder that can be remapped to another tensor at import time.
input_tensor = tf.placeholder(tf.float32, name='input')

# The output tensor must carry an explicit name so it can be frozen and
# looked up again after import.
# NOTE(review): the op itself was missing from this listing; `a + b + input`
# is reconstructed from the printed result [[11.], [12.]] for input 4 -- confirm.
output = tf.add(a + b, input_tensor, name='out')

# Convert the Variables to constants and write the network to a single file.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    graph = convert_variables_to_constants(sess, sess.graph_def, ["out"])
    tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)

import tensorflow as tf

with tf.Session() as sess:
    # Load the serialized GraphDef from disk (binary mode for the protobuf parser).
    with open('./graph.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        # Without this parse step graph_def stays empty and neither 'input:0'
        # nor 'out:0' can be resolved.
        graph_def.ParseFromString(f.read())

    # input_map substitutes the graph's 'input:0' tensor with the constant 4.0;
    # name='a' is the prefix under which the imported ops are placed.
    output = tf.import_graph_def(
        graph_def,
        input_map={'input:0': 4.},
        return_elements=['out:0'],
        name='a',
    )
    print(sess.run(output))

# Output: [array([[ 11.],       [ 12.]], dtype=float32)]

import tensorflow as tf

# A fresh scalar placeholder that will stand in for the imported graph's input.
new_input = tf.placeholder(tf.float32, shape=())

with tf.Session() as sess:
    # Load the serialized GraphDef from disk (binary mode for the protobuf parser).
    with open('./graph.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        # Without this parse step graph_def stays empty and neither 'input:0'
        # nor 'out:0' can be resolved.
        graph_def.ParseFromString(f.read())

    # Rewire the graph's 'input:0' tensor to new_input so the value can be
    # supplied through feed_dict at run time; name='a' is the import prefix.
    output = tf.import_graph_def(
        graph_def,
        input_map={'input:0': new_input},
        return_elements=['out:0'],
        name='a',
    )
    print(sess.run(output, feed_dict={new_input: 4}))

# Output: [array([[ 11.],       [ 12.]], dtype=float32)]

import tensorflow as tf

with tf.Session() as sess:

# 不使用'rb'模式

with open('./graph.pb', 'r') as f:

graph_def = tf.GraphDef()