tensorflow,存储读取数据结构剖析与合并多个graph,看不懂你掐死我

我又一次开始了“看不懂你掐死我系列”。标题名称是仿照知乎的一篇介绍傅里叶变换的文章起的。当时看完了觉得还真看懂了。可是关上网页再自己想的时候,就有想掐死博主的冲动~~ 为了致敬,这里贴出原文章,大家共勉。

网上下载的图片不认识,如有冒犯请联系我换图片

抄袭标题:看不懂傅里叶变换就掐死他

这段时间做训练的时候需要分步训练不同的网络结构,最后把所有训练好的graph合并成一个大graph,前后接起来并且重新定义输入和输出再继续训练,这样分步先训练小网络再合成大网络的话效果会好一点,收敛的也会快一些。那么有个问题,怎么把训练好的好几个graph恢复训练参数再合并到一起呢?Tensorflow到底能不能这么做?如果能,那应该怎么做?在读了这篇这篇知乎,和搜了无数个stackoverflow上的例子之后,终于有了答案。


要知道我们需要把每个pretrain的网络的结构和参数全都读进去,再把它们合并在一起。先不说合并的事,读取参数和结构就是个问题。比如下边这几个stackoverflow的帖子。12345。他们都用了不同的读取方法。但是到底读取的是什么?有没有达到我们预期的目的却不清楚。所以我意识到先要把tensorflow的内部结构搞清楚,看看存有什么东西,再看看存储和读取的方式。先来看结构。

Tensorflow的内部结构:

上面的这篇知乎都说的挺清楚的,我就捡这最重要的总结一下。

我们都知道tensorflow里有graph,graph的节点就是运算operation。这个用tensorboard可视化可以看到。比如下面这就是个简单的graph。

graph示例

这个graph在tensorflow里实际的存储方式是被序列化以后,以Protocol Buffer的形式存储的。这里有中文的对protobuf的介绍,是google开发的。

graph序列化的protobuf叫做graphDef,就是define graph的意思,一个graph的定义。这个graphDef可以用tf.train.write_graph()/tf.Import_graph_def()来写入和导出。上面stackoverflow里就有人用这个方法。然而graphDef里面其实是没有存储变量的,但是可以存常量,就是constant。可以用一种叫freeze_graph的工具把变量替换成常量,这里有官方的介绍。一般来说没有必要这么做,因为既然存了网络,肯定有变量的信息,虽然不在graphDef里面,但是肯定在别的地方。其实它存在collectionDef里。还有一些其他的Def,所以干脆归纳一下:

MetaGraph - MetaInfoDef 这个是存metadata的,像版本信息啊,用户信息啥的

                    - GraphDef 上面说的就是这个GraphDef

                    - SaverDef 这个就是tf.train.Saver的saver

                    - CollectionDef

这些Def的数据都存在一个叫MetaGraph的文件里。这个MetaGraph有官方介绍

最后面的collectionDef就是各种集合。每个集合里都是1对多的key/value pairs。你也可以把你想要的变量存进某个即合理,用tf.add_to_collection(collection_name,变量)就行。然后再用tf.get_collection()取出来。比如我有loss和train_op,就可以:

tf.add_to_collection("training_collection",loss)

tf.add_to_collection("training_collection",train_op)

然后再用

Train_collect = tf.get_collection(“training_collection”)  #得到一个python list

list里面就是你之前存的东西。所以collection我的理解就是为了方便管理变量用的。

metagraph可以用export_meta_graph/Import_meta_graph来导入导出。

这里注意了,如果你用tf.Import_graph_def()导入graphDef的话,导入的东西一般是不能训练的。但是用Import_meta_graph来导入metagraph之后,就是导入了一个完整的结构,这时候是可以训练的

虽然能训练,metagraph里也有变量,但是都是起始值。也就是说我们之前训练的参数是没有导入的。这里训练等于是从头训练。实际的训练参数没有存在metagraph里,而是在data文件里。这个下面会提到。


说完了tensorflow的结构,再说说存储的方式。看完这节,你应该完全知道什么api是用来读什么的了。


存储与读取:

上面那篇中文知乎恰好总结了这些。一般存读有3个API:

tf.train.Saver()/saver.restore()

export_meta_graph/Import_meta_graph

tf.train.write_graph()/tf.Import_graph_def()

后两个上一节都见过了。现在说说第一个。

我平时常用的只有第一个tf.train.Saver()和saver.restore()。我也看到很多代码里这么写。但是有一点很坑爹的是tf.train.saver.save() 什么都保存。但是在恢复图时,tf.train.saver.restore() 只恢复 Variable,如果要从MetaGraph恢复图,需要使用 import_meta_graph。看明白了吗?saver.save()和saver.restore()保存和读取的东西不!一!样!也就是说如果我想重组graph,要么用Import_meta_graph来导入graph,之后再saver.restore();要么就从新建立graph,把tensor传入结构的过程再写一遍,然后再saver.restore()。不然连变量名都找不到肯定会报错。

说道存储,我们必须得看看存储文件的格式。如果你用saver.save()保存的话(好像也只有这一种方法),打开你的保存文件夹,你会看到4种后缀名的文件(events开头的不算,那是tf.summary生成给tensorboard用的),分别是:

checkpoint - 就是一个账本文件,可以使用高级帮助程序来加载不同的时间保存的chkp文件。没什么用

.meta - 保存压缩后的Metagraph的protobufs,其实就是Metagraph。

.index - 包含一个不可变的键值表,用于链接序列化的张量名称以及在chkp.data文件中查找其数据的位置,也没存什么实际东西

.data - 这个里面才是存了训练后的参数。通常比.meta要大。有的时候有多个data文件用于共享或创建多个训练的时间戳。

其中.data文件的名字一般都是这种格式的:

<prefix>-<global_step>.data-<shard_index>-of-<number_of_shards>.

比如:

存储名的例子

所以saver.restore()的时候其实是restore的.data文件。当然在restore之前可以用tf.train.latest_checkpoint()来得到最后一次存储点。还有一点是在saver.save()和restore的时候,那个文件对象是xxx.ckpt。但实际上在存储文件夹里你找不到xxx.ckpt文件。这个也是正常的。官方文档有说.ckpt文件其实是隐性的的。所以除非你文件名字输入错了,不然不用担心读错文件。

下面结合我的实例再看看怎么合并graph。


实例:

先稍微介绍一下网络的结构。我有四个网络结构。其中3个网络是平行的,这里就叫p1,p2和p3吧。最后一个网络是微调用的,就叫m吧。这个m会得到3个网络的输出,合并在一起作为m的输入,输入到m,最后得到最终结果。为了方便理解我画了个图。

总体网络结构示意图

如果直接训练这么大的网络,收敛起来一定很费劲,有可能某一个网络落到一个local minimum就出不去了。所以我们把p1,p2,p3拿出来单独训练,每次只训练一个。

我分别用数据训练这3个网络。这个训练阶段算是pretrain。待到三个网络都稳定的时候,我把它们的输出结果加在一起,输入到第四个网络里训练整个网络。

官方文件称feed_dicts是效率最低的方法,所以我们改用的tfrecord和dataset api来读取文件。如果你不清楚这是啥,可以参看我们办公室博导的简书,这家伙可厉害了~

现在有两个问题,1是用Import_meta_graph导入metagraph的方法没法合并graph,因为我写的数据导入之后拿不出来(或者说我不知道怎么拿出来,可能有api可以取出来)。p1,p2,p3的输出数据是要手动连接的。import_graph_def()也可以设置input,output mapping,但是我这里没有tf.placeholder。我必须拿到一个从p1,p2,p3合成出来的tensor,再塞到m里去。所以我选择了用重建graph的方法。用

traindata, label = data_iterator(tfrecord_path).get_next() 

得到数据,再把traindata分别放入p1,p2,p3的架构中:

out_p1 = networkp1(trandata_p1)

网络结构有了,再restore参数:

full_path = tf.train.latest_checkpoint(model_ckp)

saver.restore(sess, full_path)

p2和p3也这么做。

三个全恢复了会得到三个output,再合并

m_data = out_p1 + out_p2 + out_p3

再输入m中:

output_m = networkm(m_data)

之后再做loss,bp,summary啥的,就可以训练了。

需要注意的是,别恢复错了graph。不要建3个session下分别用3个graph恢复,因为那样到

m_data = out_p1 + out_p2 + out_p3 #如果三个out是不同的graph,这里会报错

这一步会报错。说不同的graph出来的结果是不能相互运算的。大家必须是在同一个graph里才行。所以要建一个session,在这个session下挨个恢复:

with tf.session as sess: # 下面每个restore里不要单建 with tf.graph():... 

    # restore p1

    # restore p2

    # retore p3

    # ....

等于是把大家依次放进default graph里。再填上最后的m就ok了。



references:

https://blog.metaflow.fr/tensorflow-saving-restoring-and-mixing-multiple-models-c4c94d5d7125

https://zhuanlan.zhihu.com/p/31308381

https://www.tensorflow.org/api_guides/python/meta_graph#What_s_in_a_MetaGraph

https://www.jianshu.com/p/0f9f2bb962f4

stackoverflow:

https://stackoverflow.com/questions/41990014/load-multiple-models-in-tensorflow

https://stackoverflow.com/questions/45093688/how-to-understand-sess-as-default-and-sess-graph-as-default

https://stackoverflow.com/questions/49864234/tensorflow-restoring-variables-from-two-checkpoints-after-combining-two-graphs

https://stackoverflow.com/questions/49490262/combining-graphs-is-there-a-tensorflow-import-graph-def-equivalent-for-c

https://stackoverflow.com/questions/41607144/loading-two-models-from-saver-in-the-same-tensorflow-session

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 159,290评论 4 363
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 67,399评论 1 294
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 109,021评论 0 243
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 44,034评论 0 207
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 52,412评论 3 287
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 40,651评论 1 219
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 31,902评论 2 313
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 30,605评论 0 199
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 34,339评论 1 246
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 30,586评论 2 246
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 32,076评论 1 261
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 28,400评论 2 253
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 33,060评论 3 236
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 26,083评论 0 8
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 26,851评论 0 195
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 35,685评论 2 274
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 35,595评论 2 270

推荐阅读更多精彩内容

  • 这篇文章是针对有tensorflow基础但是记不住复杂变量函数的读者,文章列举了从输入变量到前向传播,反向优化,数...
    horsetif阅读 1,133评论 0 1
  • 在这篇tensorflow教程中,我会解释: 1) Tensorflow的模型(model)长什么样子? 2) 如...
    JunsorPeng阅读 3,333评论 1 6
  • 芦花编履求微暖,苇席浮棚惧雨侵。 今夕泛舟湖荡过,不谈风雪只弹琴。 注:1、芦花靴。用芦花、稻草编织的鞋,冬季穿,...
    真老实人_425a阅读 807评论 1 12
  • 每到夜晚入睡前,总是会突然一缕迷茫上心头。 是在回想什么? 还是在遗憾什么? 对于现在的自己是否很满意? 我给出否...
    心安若阅读 592评论 0 0
  • 一生中,有你,觉得幸福;没你,也很快乐。有时相见不如怀念;有时怀念不如相见。时过境迁,事过人非。曾经的执...
    甲子梅阅读 221评论 1 1