Tensorflow动态seq2seq使用总结(r1.3)

动机

其实差不多半年之前就想吐槽Tensorflow的seq2seq了(后面博主去干了些别的事情),官方的代码已经抛弃原来用静态rnn实现的版本了,而官网的tutorial现在还是介绍基于静态的rnn的模型,加bucket那套,看这里

tutorial.png

看到了吗?是legacy_seq2seq的。本来Tensorflow的seq2seq的实现相比于pytorch已经很复杂了,还没有个正经的tutorial,哎。
好的,回到正题,遇到问题解决问题,想办法找一个最佳的Tensorflow的seq2seq解决方案

学习的资料

  • 知名博主WildML给google写了个通用的seq2seq,文档地址Github地址。这个框架已经被Tensorflow采用,后面我们的代码也会基于这里的实现。但本身这个框架是为了让用户直接写参数就能简单地构建网络,因此文档没有太多参考价值,我们直接借用其中的代码构建自己的网络。
  • 俄罗斯小伙ematvey写的:tensorflow-seq2seq-tutorials,Github地址。介绍使用动态rnn构建seq2seq,decoder使用raw_rnn,原理和WildML的方案差不多。多说一句,这哥们当时也是吐槽Tensorflow的文档,写了那么个仓库当第三方的文档使,现在都400+个star了。真是有漏洞就有机遇啊,哈哈。

Tensorflow的动态rnn

先来简单介绍动态rnn和静态rnn的区别。
tf.nn.rnn creates an unrolled graph for a fixed RNN length. That means, if you call tf.nn.rnn with inputs having 200 time steps you are creating a static graph with 200 RNN steps. First, graph creation is slow. Second, you’re unable to pass in longer sequences (> 200) than you’ve originally specified.tf.nn.dynamic_rnn solves this. It uses a tf.While loop to dynamically construct the graph when it is executed. That means graph creation is faster and you can feed batches of variable size.

摘自Whats the difference between tensorflow dynamic_rnn and rnn?。也就是说,静态的rnn必须提前将图展开,在执行的时候,图是固定的,并且最大长度有限制。而动态rnn可以在执行的时候,将图循环地的复用。

一句话,能用动态的rnn就尽量用动态的吧

Seq2Seq结构分析

seq2seq.png

seq2seq由Encoder和Decoder组成,一般Encoder和Decoder都是基于RNN。Encoder相对比较简单,不管是多层还是双向或者更换具体的Cell,使用原生API还是比较容易实现的。难点在于Decoder:不同的Decoder对应的rnn cell的输入不同,比如上图的示例中,每个cell的输入是上一个时刻cell输出的预测对应的embedding。

attention.png

如果像上图那样使用Attention,则decoder的cell输入还包括attention加权求和过的context。

通过示例讲解

slot filling.png

下面通过一个用seq2seq做slot filling(一种序列标注)的例子讲解。完整代码地址:https://github.com/applenob/RNN-for-Joint-NLU

Encoder的实现示例

# 首先构造单个rnn cell
encoder_f_cell = LSTMCell(self.hidden_size)
encoder_b_cell = LSTMCell(self.hidden_size)
 (encoder_fw_outputs, encoder_bw_outputs),
 (encoder_fw_final_state, encoder_bw_final_state) = \
        tf.nn.bidirectional_dynamic_rnn(cell_fw=encoder_f_cell,
                                            cell_bw=encoder_b_cell,
                                            inputs=self.encoder_inputs_embedded,
                                            sequence_length=self.encoder_inputs_actual_length,
                                            dtype=tf.float32, time_major=True)

上面的代码使用了tf.nn.bidirectional_dynamic_rnn构建单层双向的LSTM的RNN作为Encoder。
参数:

  • cell_fw:前向的lstm cell
  • cell_bw:后向的lstm cell
  • time_major:如果是True,则输入需要是T×B×E,T代表时间序列的长度,B代表batch size,E代表词向量的维度。否则,为B×T×E。输出也是类似。

返回:

  • outputs:针对所有时间序列上的输出。
  • final_state:只是最后一个时间节点的状态。

一句话,Encoder的构造就是构造一个RNN,获得输出和最后的状态。

Decoder实现示例

下面着重介绍如何使用Tensorflow的tf.contrib.seq2seq实现一个Decoder。
我们这里的Decoder中,每个输入除了上一个时间节点的输出以外,还有对应时间节点的Encoder的输出,以及attention的context。

Helper

常用的Helper

  • TrainingHelper:适用于训练的helper。
  • InferenceHelper:适用于测试的helper。
  • GreedyEmbeddingHelper:适用于测试中采用Greedy策略sample的helper。
  • CustomHelper:用户自定义的helper。

先来说明helper是干什么的:参考上面提到的俄罗斯小哥用raw_rnn实现decoder,需要传进一个loop_fn。这个loop_fn其实是控制每个cell在不同的时间节点,给定上一个时刻的输出,如何决定下一个时刻的输入。
helper干的事情和这个loop_fn基本一致。这里着重介绍CustomHelper,要传入三个函数作为参数:

  • initialize_fn:返回finishednext_inputs。其中finished不是scala,是一个一维向量。这个函数即获取第一个时间节点的输入。
  • sample_fn:接收参数(time, outputs, state) 返回sample_ids。即,根据每个cell的输出,如何sample。
  • next_inputs_fn:接收参数(time, outputs, state, sample_ids) 返回 (finished, next_inputs, next_state),根据上一个时刻的输出,决定下一个时刻的输入。

BasicDecoder

有了自定义的helper以后,可以使用tf.contrib.seq2seq.BasicDecoder定义自己的Decoder了。再使用tf.contrib.seq2seq.dynamic_decode执行decode,最终返回:(final_outputs, final_state, final_sequence_lengths)。其中:final_outputstf.contrib.seq2seq.BasicDecoderOutput类型,包括两个字段:rnn_outputsample_id

回到示例

        # 传给CustomHelper的三个函数
        def initial_fn():
            initial_elements_finished = (0 >= decoder_lengths)  # all False at the initial step
            initial_input = tf.concat((sos_step_embedded, encoder_outputs[0]), 1)
            return initial_elements_finished, initial_input

        def sample_fn(time, outputs, state):
            # 选择logit最大的下标作为sample
            prediction_id = tf.to_int32(tf.argmax(outputs, axis=1))
            return prediction_id

        def next_inputs_fn(time, outputs, state, sample_ids):
            # 上一个时间节点上的输出类别,获取embedding再作为下一个时间节点的输入
            pred_embedding = tf.nn.embedding_lookup(self.embeddings, sample_ids)
            # 输入是h_i+o_{i-1}+c_i
            next_input = tf.concat((pred_embedding, encoder_outputs[time]), 1)
            elements_finished = (time >= decoder_lengths)  # this operation produces boolean tensor of [batch_size]
            all_finished = tf.reduce_all(elements_finished)  # -> boolean scalar
            next_inputs = tf.cond(all_finished, lambda: pad_step_embedded, lambda: next_input)
            next_state = state
            return elements_finished, next_inputs, next_state

        # 自定义helper
        my_helper = tf.contrib.seq2seq.CustomHelper(initial_fn, sample_fn, next_inputs_fn)

        def decode(helper, scope, reuse=None):
            with tf.variable_scope(scope, reuse=reuse):
                memory = tf.transpose(encoder_outputs, [1, 0, 2])
                attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(
                    num_units=self.hidden_size, memory=memory,
                    memory_sequence_length=self.encoder_inputs_actual_length)
                cell = tf.contrib.rnn.LSTMCell(num_units=self.hidden_size * 2)
                attn_cell = tf.contrib.seq2seq.AttentionWrapper(
                    cell, attention_mechanism, attention_layer_size=self.hidden_size)
                out_cell = tf.contrib.rnn.OutputProjectionWrapper(
                    attn_cell, self.slot_size, reuse=reuse
                )
                # 使用自定义helper的decoder
                decoder = tf.contrib.seq2seq.BasicDecoder(
                    cell=out_cell, helper=helper,
                    initial_state=out_cell.zero_state(
                        dtype=tf.float32, batch_size=self.batch_size))
                # 获取decode结果
                final_outputs, final_state, final_sequence_lengths = tf.contrib.seq2seq.dynamic_decode(
                    decoder=decoder, output_time_major=True,
                    impute_finished=True, maximum_iterations=self.input_steps
                )
                return final_outputs

        outputs = decode(my_helper, 'decode')

Attntion

上面的代码,还有几个地方没有解释:BahdanauAttentionAttentionWrapperOutputProjectionWrapper

先从简单的开始:OutputProjectionWrapper即做一个线性映射,比如之前的cell的ouput是T×B×D,D是hidden size,那么这里做一个线性映射,直接到T×B×S,这里S是slot class num。wrapper内部维护一个线性映射用的变量:Wb

attention.png

BahdanauAttention是一种AttentionMechanism,另外一种是:BahdanauMonotonicAttention。具体二者的区别,读者请自行深入调查。关键参数:

  • num_units:隐层维度。
  • memory:通常就是RNN encoder的输出
  • memory_sequence_length=None:可选参数,即memory的mask,超过长度数据不计入attention。

继续介绍AttentionWrapper:这也是一个cell wrapper,关键参数:

  • cell:被包装的cell。
  • attention_mechanism:使用的attention机制,上面介绍的。
attention.png

memory对应公式中的h,wrapper的输出是s。

那么一个AttentionWrapper具体的操作流程如何呢?看官网给的流程:

AttentionWrapper.png

Loss Function

tf.contrib.seq2seq.sequence_loss可以直接计算序列的损失函数,重要参数:

  • logits:尺寸[batch_size, sequence_length, num_decoder_symbols]
  • targets:尺寸[batch_size, sequence_length],不用做one_hot。
  • weights[batch_size, sequence_length],即mask,滤去padding的loss计算,使loss计算更准确。

后记

这里只讨论了seq2seq在序列标注上的应用。seq2seq还广泛应用于翻译和对话生成,涉及到生成的策略问题,比如beam search。后面会继续研究。除了sample的策略,其他seq2seq的主要技术,本文已经基本涵盖,希望对大家踩坑有帮助。
完整代码:https://github.com/applenob/RNN-for-Joint-NLU

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 158,847评论 4 362
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 67,208评论 1 292
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 108,587评论 0 243
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 43,942评论 0 205
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 52,332评论 3 287
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 40,587评论 1 218
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 31,853评论 2 312
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 30,568评论 0 198
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 34,273评论 1 242
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 30,542评论 2 246
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 32,033评论 1 260
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 28,373评论 2 253
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 33,031评论 3 236
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 26,073评论 0 8
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 26,830评论 0 195
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 35,628评论 2 274
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 35,537评论 2 269

推荐阅读更多精彩内容