Tensorflow win10 c++ 运行 python训练出的模型

简介

由于生产环境使用windows、C++,而tensorflow模型训练使用python更为方便,因此存在需求:在windows环境使用tensorflow的c++接口载入训练好的tensorflow模型,并进行测试。类似的文档比较缺乏,并且由于tf本身一直在完善,相比现有的博客各个步骤都有进一步的简化,这里针对1.2.0版本梳理对应的最简单的一种流程:

  1. 利用tensorflow的python API定义、训练自己的模型
  2. 利用tensorflow的python API保存模型,并进一步将模型中的变量都转化为常量,通过这样“freeze graph”使得模型导出为一个文件,便于c++调用
  3. 编译tensorflow的源码来使用tensorflow的c++接口
  4. 在tensorflow的tutorrials Image Recognition 的基础上修改代码,利用模型进行测试。

利用tf的python API训练模型

这部分属于tensorflow的基础,官方文档getting started有相当详细的介绍和描述,在此不做赘述。值得注意的是tf的命名方式,在python代码中的变量名和在tf的graph中的变量名是两个概念,因此至少针对输入输出要定义tf的graph中的变量名,定义变量名的语法类似loss = tf.reduce_mean(cross_entropy, name='xentropy_mean')。此外也可以利用tf.name_scope
来规划命名。

导出tf模型并freeze graph

这部分有官方工具的代码freeze_graph.py,对应的博客也很多。这里我推荐博客TensorFlow: How to freeze a model and serve it with a python API
  freeze graph就是把原本的图中的变量(卷积核、偏置)等都使用训练好的模型中的值来代替,变成常量。frozen graph的意义在于(freeze_graph.py的注释)

It's useful to do this when we need to load a single file in C++, especially in environments like mobile or embedded where we may not have access to the RestoreTensor ops and file loading calls that they rely on.

推荐的主要原因在于博客中使用方法saver = tf.train.Saver();last_chkp = saver.save(sess, 'results/graph.chkp')是最为简单的保存模型的方法,同时博客提供了freeze graph的代码,核心采用graph_util.convert_variables_to_constants 方法来进行freeze graph,使得不需要使用官方工具freeze_graph.py。对应freeze_graph的代码引用如下(其中注意到write使用参数‘wb'写为二进制):


import os, argparse

import tensorflow as tf
from tensorflow.python.framework import graph_util

dir = os.path.dirname(os.path.realpath(__file__))

def freeze_graph(model_folder):
    # We retrieve our checkpoint fullpath
    checkpoint = tf.train.get_checkpoint_state(model_folder)
    input_checkpoint = checkpoint.model_checkpoint_path
    
    # We precise the file fullname of our freezed graph
    absolute_model_folder = "/".join(input_checkpoint.split('/')[:-1])
    output_graph = absolute_model_folder + "/frozen_model.pb"

    # Before exporting our graph, we need to precise what is our output node
    # This is how TF decides what part of the Graph he has to keep and what part it can dump
    # NOTE: this variable is plural, because you can have multiple output nodes
    output_node_names = "Accuracy/predictions"

    # We clear devices to allow TensorFlow to control on which device it will load operations
    clear_devices = True
    
    # We import the meta graph and retrieve a Saver
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=clear_devices)

    # We retrieve the protobuf graph definition
    graph = tf.get_default_graph()
    input_graph_def = graph.as_graph_def()

    # We start a session and restore the graph weights
    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)

        # We use a built-in TF helper to export variables to constants
        output_graph_def = graph_util.convert_variables_to_constants(
            sess, # The session is used to retrieve the weights
            input_graph_def, # The graph_def is used to retrieve the nodes 
            output_node_names.split(",") # The output node names are used to select the usefull nodes
        ) 

        # Finally we serialize and dump the output graph to the filesystem
        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(output_graph_def.node))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_folder", type=str, help="Model folder to export")
    args = parser.parse_args()

    freeze_graph(args.model_folder)

编译源码来使用tf的c++ API

编译源码的方式官方有文档Installing TensorFlow from Sources,其中有段:

We don't officially support building TensorFlow on Windows; however, you may try to build TensorFlow on Windows if you don't mind using the highly experimental Bazel on Windows or TensorFlow CMake build.

在两种方案中,我选择采用cmake,理由是相对来说环境配置更为容易,但可能使用google自己的bazel相对支持度更高。
  参考官方readme一步一步来,值得注意的有两点,一个是git clone的时候推荐git对应的稳定版本的分支(直接master可能会有编译错误和未知bug);另一个是要用命令行进行编译,直接采用vs2015 IDE进行编译会出错C1060,原因应该是默认的编译器调用的不是native 64位的toolset,如何设置使得能够使用IDE直接编译调试的方法还没有找到。
  相比于官方的项目tf_tutorials_example_trainer.vcxproj,更有参考意义的项目是tf_label_image_example.vcxproj,对应的详尽官方教程Image Recognition,这个教程使用inception模型来进行识别,对应运行时可能需要修改图片和文件的路径才能正确输出结果。

修改代码实现自己的模型

教程源码提供了模型读取,图片读取,Label读取等核心步骤,修改对应代码进行编译能够很容易上手完成任务,下面贴一下保存图片的代码,总体是读取图片的逆向过程:

// Given an output tensor with 4d, reduce dim and output jpg image
Status SaveTensorToImageFile(const string& file_name, const Tensor* out_tensor) {
    auto root = tensorflow::Scope::NewRootScope();
    using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)

    auto output_image_data = tensorflow::ops::Reshape(root, *out_tensor, { 256, 256, 3 });
    auto output_image_data_cast = tensorflow::ops::Cast(root, output_image_data, tensorflow::DT_UINT8);
    auto output_image = tensorflow::ops::EncodeJpeg(root, output_image_data_cast);
    auto output_op = tensorflow::ops::WriteFile(root.WithOpName("output/image"), file_name/*"D:/tf_face/trained_model_fast/output.jpg"*/, output_image);
    string output_name = "output/image";
    // This runs the GraphDef network definition that we've just constructed, and
    // returns the results in the output tensor.
    tensorflow::GraphDef graph;
    TF_RETURN_IF_ERROR(root.ToGraphDef(&graph));

    std::unique_ptr<tensorflow::Session> session(
        tensorflow::NewSession(tensorflow::SessionOptions()));
    TF_RETURN_IF_ERROR(session->Create(graph));
    Status writeResult = session->Run({}, {}, { output_name }, {});
    return writeResult;
}

代码中图片的尺寸可以自行定义,其中要注意的是c++中session->Run函数传入的参数无论是ops或是Tensor都是要使用tf定义的名字root.WithOpName("output/image")而不是c++代码中定义的局部变量output_op,以上在tf的CPU版本上流程走通。

参考链接

Tensorflow C++ API调用预训练模型和生产环境编译 (unix )
TensorFlow: How to freeze a model and serve it with a python API
TensorFlow CMake build
Tensorflow Tutorial Image Recognition

推荐阅读更多精彩内容