seq2seq models with raw_rnn

This is the second post on seq2seq models; it implements a seq2seq model with tf.nn.raw_rnn.
GitHub repo: https://github.com/zhuanxuhit/nd101
Original notebook: https://github.com/ematvey/tensorflow-seq2seq-tutorials/blob/master/2-seq2seq-advanced.ipynb

import numpy as np
import tensorflow as tf
import helper

tf.reset_default_graph()
sess = tf.InteractiveSession()
PAD = 0
EOS = 1
# UNK = 2
# GO  = 3

vocab_size = 10
input_embedding_size = 20

encoder_hidden_units = 20
decoder_hidden_units = encoder_hidden_units * 2
encoder_inputs = tf.placeholder(shape=(None, None), dtype=tf.int32, name='encoder_inputs')

encoder_inputs_length = tf.placeholder(shape=(None,), dtype=tf.int32, name='encoder_inputs_length')

decoder_targets = tf.placeholder(shape=(None, None), dtype=tf.int32, name='decoder_targets')
embeddings = tf.Variable(tf.truncated_normal([vocab_size, input_embedding_size], mean=0.0, stddev=0.1), dtype=tf.float32)
encoder_inputs_embedded = tf.nn.embedding_lookup(embeddings, encoder_inputs)

Encoder

Note that the encoder defined here is different from the one in the first post.

from tensorflow.contrib.rnn import LSTMCell, LSTMStateTuple
encoder_cell = LSTMCell(encoder_hidden_units)

For the difference between LSTMCell and BasicLSTMCell, see the note in the TensorFlow documentation:

It (BasicLSTMCell) does not allow cell clipping, a projection layer, and does not use peep-hole connections: it is the basic baseline.

The projection layer adds a fully connected layer on top of the output h and uses the result as the LSTM's output.

Cell clipping clips the cell state c to a fixed range before it is passed through the output activation, which helps keep activations from blowing up.

Peep-hole connections let the gates look at the cell state c directly, in addition to the hidden state h and the current input. All three features are exposed as constructor arguments of LSTMCell, as sketched below.
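A minimal sketch of those arguments (not used anywhere in this tutorial; the values are arbitrary examples):

cell_with_extras = LSTMCell(encoder_hidden_units,
                            use_peepholes=True,  # gates can also see the cell state c
                            cell_clip=5.0,       # clip the cell state to [-5, 5] before the output activation
                            num_proj=16)         # extra FC layer: the cell's output h is projected down to 16 units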

((encoder_fw_outputs,
  encoder_bw_outputs),
 (encoder_fw_final_state,
  encoder_bw_final_state)) = (
    tf.nn.bidirectional_dynamic_rnn(cell_fw=encoder_cell,
                                    cell_bw=encoder_cell,
                                    inputs=encoder_inputs_embedded,
                                    sequence_length=encoder_inputs_length,
                                    dtype=tf.float32, time_major=True)
    )

Since time_major=True, the input encoder_inputs_embedded has shape [max_time, batch_size, input_embedding_size].

encoder_fw_outputs has shape [max_time, batch_size, cell_fw.output_size].

sequence_length: if not provided, every batch element is unrolled for the full max_time steps; if provided, each element is only processed for its first sequence_length steps, beyond which the outputs are zeroed out and the state is simply copied through (see the toy check below).

Each final state has shape [batch_size, cell_fw.state_size].
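A toy check of the sequence_length behaviour (a sketch built in its own throw-away graph, separate from the tutorial's graph; the sizes are arbitrary):

toy_graph = tf.Graph()
with toy_graph.as_default():
    toy_cell = LSTMCell(4)
    toy_inputs = tf.placeholder(tf.float32, [5, 2, 3])  # [max_time, batch_size, depth]
    toy_lengths = tf.placeholder(tf.int32, [2])
    toy_outputs, toy_state = tf.nn.dynamic_rnn(
        toy_cell, toy_inputs, sequence_length=toy_lengths,
        dtype=tf.float32, time_major=True)
    with tf.Session() as toy_sess:
        toy_sess.run(tf.global_variables_initializer())
        out = toy_sess.run(toy_outputs,
                           {toy_inputs: np.random.randn(5, 2, 3).astype(np.float32),
                            toy_lengths: [5, 3]})
        print(out[3:, 1])  # steps past length 3 of the second sequence are all zeros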

encoder_fw_outputs
<tf.Tensor 'bidirectional_rnn/fw/fw/TensorArrayStack/TensorArrayGatherV3:0' shape=(?, ?, 20) dtype=float32>
encoder_bw_outputs
<tf.Tensor 'ReverseSequence:0' shape=(?, ?, 20) dtype=float32>
encoder_fw_final_state
LSTMStateTuple(c=<tf.Tensor 'bidirectional_rnn/fw/fw/while/Exit_2:0' shape=(?, 20) dtype=float32>, h=<tf.Tensor 'bidirectional_rnn/fw/fw/while/Exit_3:0' shape=(?, 20) dtype=float32>)
encoder_cell.state_size
LSTMStateTuple(c=20, h=20)
  • encoder_fw_final_state.c is the internal cell (memory) state of the LSTM
  • encoder_fw_final_state.h is the hidden state, i.e. the output the cell exposes at the last step, which can potentially be transformed with some wrapper

So c is the cell's internal memory, while h is derived from it through the output gate and activation (h = o * tanh(c)) and is what the cell emits as its output.

encoder_outputs = tf.concat((encoder_fw_outputs, encoder_bw_outputs), 2)

encoder_final_state_c = tf.concat(
    (encoder_fw_final_state.c, encoder_bw_final_state.c), 1)

encoder_final_state_h = tf.concat(
    (encoder_fw_final_state.h, encoder_bw_final_state.h), 1)

encoder_final_state = LSTMStateTuple(
    c=encoder_final_state_c,
    h=encoder_final_state_h
)
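A quick sanity check (sketch): after concatenating the forward and backward states, both c and h have size 2 * encoder_hidden_units = 40, which is exactly why decoder_hidden_units was set to encoder_hidden_units * 2 above.

print(encoder_final_state_c.get_shape())  # (?, 40)
print(encoder_final_state_h.get_shape())  # (?, 40)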

In this example we keep encoder_outputs instead of discarding them; they could be used for an attention mechanism.

Decoder

decoder_cell = LSTMCell(decoder_hidden_units)
encoder_max_time, batch_size = tf.unstack(tf.shape(encoder_inputs))

Next we need to decide how many steps the decoder runs before stopping. There are two strategies:

  • Stop after a specified number of unrolling steps
  • Stop after the model produces an <EOS> token

Here we pick the first one: a fixed number of steps. Earlier tutorials usually used len(encoder_input) + 10; to keep this test simple we use len(encoder_input) + 2 extra steps, plus one step for the leading <EOS> token fed to the decoder, hence the + 3 below.

decoder_lengths = encoder_inputs_length + 3
# +2 additional steps, +1 leading <EOS> token for decoder inputs

Output projection

The decoder's prediction flow is:

output(t) -> output projection(t) -> prediction(t) (argmax) -> input embedding(t+1) -> input(t+1)

We first define the output projection weights W and b; a sketch of a single step of the flow above follows after their definition.

W = tf.Variable(tf.truncated_normal([decoder_hidden_units, vocab_size], 0, 0.1), dtype=tf.float32)
b = tf.Variable(tf.zeros([vocab_size]), dtype=tf.float32)
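A sketch of a single step of the prediction flow, using the W, b and embeddings defined above (the helper name project_and_embed is just for illustration; the get_next_input closure inside loop_fn_transition below does exactly this):

def project_and_embed(prev_output):                # [batch_size, decoder_hidden_units]
    logits = tf.add(tf.matmul(prev_output, W), b)  # [batch_size, vocab_size]
    prediction = tf.argmax(logits, axis=1)         # [batch_size], predicted token ids
    return tf.nn.embedding_lookup(embeddings, prediction)  # [batch_size, input_embedding_size]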

Decoder via tf.nn.raw_rnn

dynamic_rnn has a limitation: it does not let us customize the input fed at each step, as in the figure below.

raw_rnn

Image source: http://www.wildml.com/2016/04/deep-learning-for-chatbots-part-1-introduction/

assert EOS == 1 and PAD == 0

eos_time_slice = tf.ones([batch_size], dtype=tf.int32, name='EOS')
pad_time_slice = tf.zeros([batch_size], dtype=tf.int32, name='PAD')

eos_step_embedded = tf.nn.embedding_lookup(embeddings, eos_time_slice)
pad_step_embedded = tf.nn.embedding_lookup(embeddings, pad_time_slice)

With the standard tf.nn.dynamic_rnn, the inputs for all steps (t, ..., t+n) must be supplied up front as a single Tensor; the "dynamic" part means that n may change from batch to batch.

If we want something more complex, such as feeding each step's output back in as the next step's input, or implementing soft attention, dynamic_rnn cannot do it. That is when we turn to tf.nn.raw_rnn.

The key to tf.nn.raw_rnn is writing the loop_fn function, which implements the mapping

(time, previous_cell_output, previous_cell_state, previous_loop_state) -> (elements_finished, input, cell_state, output, loop_state).

This mapping runs before the RNN cell is invoked at each step, preparing the cell's input.

loop_fn is called at two kinds of moments:

  1. Initial call at time=0 to provide initial cell_state and input to RNN.
  2. Transition call for all following timesteps where you define transition between two adjacent steps.

We define each of them below.

def loop_fn_initial():
    initial_elements_finished = (0 >= decoder_lengths)  # all False at the initial step
    initial_input = eos_step_embedded
    initial_cell_state = encoder_final_state
    initial_cell_output = None
    initial_loop_state = None  # we don't need to pass any additional information
    return (initial_elements_finished,
            initial_input,
            initial_cell_state,
            initial_cell_output,
            initial_loop_state)
# (time, previous_cell_output, previous_cell_state, previous_loop_state) -> 
#     (elements_finished, input, cell_state, output, loop_state).
def loop_fn_transition(time, previous_output, previous_state, previous_loop_state):

    def get_next_input():
        output_logits = tf.add(tf.matmul(previous_output, W), b) # projection layer
        # [batch_size, vocab_size]
        prediction = tf.argmax(output_logits, axis=1)
        next_input = tf.nn.embedding_lookup(embeddings, prediction)
        # [batch_size, input_embedding_size]
        return next_input
    
    elements_finished = (time >= decoder_lengths) # this operation produces boolean tensor of [batch_size]
                                                  # defining if corresponding sequence has ended

    finished = tf.reduce_all(elements_finished) # -> boolean scalar
    input = tf.cond(finished, lambda: pad_step_embedded, get_next_input)
    # input shape [batch_size,input_embedding_size]
    state = previous_state
    output = previous_output
    loop_state = None

    return (elements_finished, 
            input,
            state,
            output,
            loop_state)

We have defined the two pieces of loop_fn separately; now we combine them into a single function.

def loop_fn(time, previous_output, previous_state, previous_loop_state):
    if previous_state is None:    # time == 0
        assert previous_output is None and previous_state is None
        return loop_fn_initial()
    else:
        return loop_fn_transition(time, previous_output, previous_state, previous_loop_state)

decoder_outputs_ta, decoder_final_state, _ = tf.nn.raw_rnn(decoder_cell, loop_fn)
decoder_outputs = decoder_outputs_ta.stack()
decoder_outputs # hidden_size = 40
<tf.Tensor 'TensorArrayStack/TensorArrayGatherV3:0' shape=(?, ?, 40) dtype=float32>

To apply the final projection to the outputs, we need a reshape:

[max_steps, batch_size, hidden_dim] to [max_steps*batch_size, hidden_dim]

decoder_max_steps, decoder_batch_size, decoder_dim = tf.unstack(tf.shape(decoder_outputs))
decoder_outputs_flat = tf.reshape(decoder_outputs, (-1, decoder_dim))
decoder_logits_flat = tf.add(tf.matmul(decoder_outputs_flat, W), b)
decoder_logits = tf.reshape(decoder_logits_flat, (decoder_max_steps, decoder_batch_size, vocab_size))
decoder_prediction = tf.argmax(decoder_logits, 2)

Optimizer

The RNN output has shape [max_time, batch_size, hidden_units]; a fully connected (projection) layer transforms it into [max_time, batch_size, vocab_size]. vocab_size is static, while max_time and batch_size are dynamic.

print(decoder_targets)
Tensor("decoder_targets:0", shape=(?, ?), dtype=int32)
print(tf.one_hot(decoder_targets, depth=vocab_size, dtype=tf.float32))
Tensor("one_hot:0", shape=(?, ?, 10), dtype=float32)
stepwise_cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
    labels=tf.one_hot(decoder_targets, depth=vocab_size, dtype=tf.float32),
    logits=decoder_logits,
)

loss = tf.reduce_mean(stepwise_cross_entropy)
train_op = tf.train.AdamOptimizer().minimize(loss)
sess.run(tf.global_variables_initializer())

Model training

Training is the same as in the first post.

batch_size = 100

batches = helper.random_sequences(length_from=3, length_to=8,
                                   vocab_lower=2, vocab_upper=10,
                                   batch_size=batch_size)

print('head of the batch:')
for seq in next(batches)[:10]:
    print(seq)
head of the batch:
[3, 8, 9, 4, 9, 4]
[2, 3, 7, 2, 5, 2]
[4, 8, 8, 2]
[9, 6, 2, 5, 8, 3, 5, 3]
[9, 7, 2, 6]
[7, 6, 2]
[4, 9, 6, 6]
[8, 4, 7, 8, 9, 8, 7]
[4, 3, 6, 7, 4, 2]
[6, 5, 8]
def next_feed():
    batch = next(batches)
    encoder_inputs_, encoder_input_lengths_ = helper.batch(batch)
    decoder_targets_, _ = helper.batch(
        [(sequence) + [EOS] + [PAD] * 2 for sequence in batch]
    )
    return {
        encoder_inputs: encoder_inputs_,
        encoder_inputs_length: encoder_input_lengths_,
        decoder_targets: decoder_targets_,
    }
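A worked example of how the targets are built (a sketch, not part of next_feed): for an input sequence [5, 6, 7] the decoder target becomes the input followed by <EOS> and two pads, matching decoder_lengths = encoder_inputs_length + 3 above.

example = [5, 6, 7]
print(example + [EOS] + [PAD] * 2)  # [5, 6, 7, 1, 0, 0]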
loss_track = []
max_batches = 3001
batches_in_epoch = 1000

try:
    for batch in range(max_batches):
        fd = next_feed()
        _, l = sess.run([train_op, loss], fd)
        loss_track.append(l)

        if batch == 0 or batch % batches_in_epoch == 0:
            print('batch {}'.format(batch))
            print('  minibatch loss: {}'.format(sess.run(loss, fd)))
            predict_ = sess.run(decoder_prediction, fd)
            for i, (inp, pred) in enumerate(zip(fd[encoder_inputs].T, predict_.T)):
                print('  sample {}:'.format(i + 1))
                print('    input     > {}'.format(inp))
                print('    predicted > {}'.format(pred))
                if i >= 2:
                    break
            print()

except KeyboardInterrupt:
    print('training interrupted')
batch 0
  minibatch loss: 2.3008058071136475
  sample 1:
    input     > [2 7 7 4 0 0 0 0]
    predicted > [8 8 0 0 8 0 0 0 0 0 0]
  sample 2:
    input     > [5 2 6 3 2 4 4 0]
    predicted > [8 9 9 9 9 9 9 9 9 9 0]
  sample 3:
    input     > [7 8 3 6 0 0 0 0]
    predicted > [5 1 8 1 4 0 0 0 0 0 0]

batch 1000
  minibatch loss: 0.7816314697265625
  sample 1:
    input     > [4 7 4 4 0 0 0 0]
    predicted > [4 4 4 4 1 0 0 0 0 0 0]
  sample 2:
    input     > [4 5 4 5 5 3 7 0]
    predicted > [4 5 5 5 5 5 5 1 0 0 0]
  sample 3:
    input     > [6 6 7 7 0 0 0 0]
    predicted > [6 7 7 7 1 0 0 0 0 0 0]

batch 2000
  minibatch loss: 0.37288352847099304
  sample 1:
    input     > [7 3 2 3 0 0 0 0]
    predicted > [7 3 2 3 1 0 0 0 0 0 0]
  sample 2:
    input     > [9 6 3 9 6 0 0 0]
    predicted > [9 9 6 6 6 1 0 0 0 0 0]
  sample 3:
    input     > [8 4 6 6 2 0 0 0]
    predicted > [8 4 6 6 2 1 0 0 0 0 0]

batch 3000
  minibatch loss: 0.22249329090118408
  sample 1:
    input     > [6 3 5 0 0 0 0 0]
    predicted > [6 3 5 1 0 0 0 0 0 0 0]
  sample 2:
    input     > [3 2 5 5 2 4 0 0]
    predicted > [3 2 5 5 2 4 1 0 0 0 0]
  sample 3:
    input     > [2 7 3 0 0 0 0 0]
    predicted > [2 7 3 1 0 0 0 0 0 0 0]
%matplotlib inline
import matplotlib.pyplot as plt
plt.plot(loss_track)
print('loss {:.4f} after {} examples (batch_size={})'.format(loss_track[-1], len(loss_track)*batch_size, batch_size))
loss 0.2228 after 300100 examples (batch_size=100)
(loss curve plot: output_45_1.png)