[RNN] Simple LSTM Code Implementation & BPTT Theoretical Derivation


Earlier we walked through the derivation of the ordinary BP (backpropagation) algorithm for CNNs. In RNNs such as the LSTM, backpropagation is instead called BPTT (Back Propagation Through Time), because the gradients have to be propagated along the time sequence.

Back Propagation Through Time

A few weeks ago I released some code on Github to help people understand how LSTMs work at the implementation level. The forward pass is well explained elsewhere and is straightforward to understand, but I derived the backprop equations myself and the backprop code came without any explanation whatsoever. The goal of this post is to explain the so-called backpropagation through time in the context of LSTMs.

If you feel like anything is confusing, please post a comment below or submit an issue on Github.

Note: this post assumes you understand the forward pass of an LSTM network, as this part is relatively simple. Please read this great intro paper if you are not familiar with it, as it contains a very nice intro to LSTMs. I follow the same notation as that paper, so I recommend keeping the tutorial open in a separate browser tab for easy reference while reading this post.

Introduction (Simple LSTM)

(Figure: LSTM block)

The forward pass of an LSTM node is defined as follows:

![][01]
[01]:http://latex.codecogs.com/png.latex?\begin{aligned}g(t)%20&=%20\phi(W_{gx}%20x(t)%20+%20W_{gh}%20h(t-1)%20+%20b_{g})%20\\%20i(t)%20&=%20\sigma(W_{ix}%20x(t)%20+%20W_{ih}%20h(t-1)%20+%20b_{i})%20\\%20f(t)%20&=%20\sigma(W_{fx}%20x(t)%20+%20W_{fh}%20h(t-1)%20+%20b_{f})%20\\%20o(t)%20&=%20\sigma(W_{ox}%20x(t)%20+%20W_{oh}%20h(t-1)%20+%20b_{o})%20\\%20s(t)%20&=%20g(t)%20*%20i(t)%20+%20s(t-1)%20*%20f(t)%20\\%20h(t)%20&=%20s(t)%20*%20o(t)\end{aligned}

(Note: in the last equation for h(t), most formulations apply a tanh to s(t) before multiplying by o(t). However, the peephole LSTM paper suggests that the activation here can simply be f(x) = x, so no tanh is used in this post (likewise below); see Wiki - Long_short-term_memory.)
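In symbols, the common formulation versus the one used in this post (a one-line comparison added here for clarity, not from the original):

    h(t) = \tanh(s(t)) * o(t) \quad \text{(common variant)} \qquad h(t) = s(t) * o(t) \quad \text{(used here)}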

By concatenating the x(t) and h(t-1) vectors as follows:

![][02]
[02]:http://latex.codecogs.com/png.latex?x_c(t)%20=%20[x(t),%20h(t-1)]

we can rewrite parts of the above as follows:

![][03]
[03]:http://latex.codecogs.com/png.latex?\begin{aligned}g(t)%20&=%20\phi(W_{g}%20x_c(t)%20+%20b_{g})%20\\%20i(t)%20&=%20\sigma(W_{i}%20x_c(t)%20+%20b_{i})%20\\%20f(t)%20&=%20\sigma(W_{f}%20x_c(t)%20+%20b_{f})%20\\%20o(t)%20&=%20\sigma(W_{o}%20x_c(t)%20+%20b_{o})\end{aligned}
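To make the forward pass concrete, here is a minimal NumPy sketch of a single step in the concatenated form above. This is an illustration of mine, not the repository's code: the function name lstm_step_forward and the way parameters are passed are assumptions, and, per the note above, no tanh is applied to s(t) when computing h(t).

import numpy as np

def sigmoid(z):
    return 1. / (1. + np.exp(-z))

def lstm_step_forward(x, h_prev, s_prev, W_g, W_i, W_f, W_o, b_g, b_i, b_f, b_o):
    # x_c(t) = [x(t), h(t-1)]
    xc = np.hstack((x, h_prev))
    g = np.tanh(np.dot(W_g, xc) + b_g)   # candidate cell input, phi = tanh
    i = sigmoid(np.dot(W_i, xc) + b_i)   # input gate
    f = sigmoid(np.dot(W_f, xc) + b_f)   # forget gate
    o = sigmoid(np.dot(W_o, xc) + b_o)   # output gate
    s = g * i + s_prev * f               # new cell state
    h = s * o                            # no tanh on s(t) here, per the note above
    return g, i, f, o, s, h, xc

The returned g, i, f, o, s and xc correspond to the quantities that the backprop code later in this post reads back through self.state and self.xc.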

Suppose we have a loss l(t) that we wish to minimize at every time step t, and that this loss depends on the hidden layer h and the label y at the current time via a loss function f:

![][04]
[04]:http://latex.codecogs.com/png.latex?l(t)%20=%20f(h(t),%20y(t))

where f can be any differentiable loss function, such as the Euclidean loss:

![][05]
[05]:http://latex.codecogs.com/png.latex?l(t)%20=%20f(h(t),%20y(t))%20=%20\|%20h(t)%20-%20y(t)%20\|^2

Our ultimate goal in this case is to use gradient descent to minimize the loss L over an entire sequence of length T:

![][06]
[06]:http://latex.codecogs.com/png.latex?L%20=%20\sum_{t=1}^{T}%20l(t)
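As a minimal sketch (the helper names are mine, not from the referenced code), the Euclidean loss and the per-step gradient dl(t)/dh(t) that the backward pass will consume look like this:

import numpy as np

def euclidean_loss(h, y):
    # l(t) = f(h(t), y(t)) = || h(t) - y(t) ||^2
    return np.sum((h - y) ** 2)

def euclidean_loss_grad(h, y):
    # dl(t)/dh(t) = 2 * (h(t) - y(t)), the elementwise derivative used later as part of top_diff_h
    return 2. * (h - y)

# total loss over a sequence of length T: L = sum over t of l(t), e.g.
# L = sum(euclidean_loss(h_t, y_t) for h_t, y_t in zip(h_seq, y_seq))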

Let’s work through the algebra of computing the loss gradient:

![][07]
[07]:http://latex.codecogs.com/png.latex?\frac{dL}{dw}

where w is a scalar parameter of the model (for example it may be an entry in the matrix W_gx). Since the loss l(t) = f(h(t),y(t)) only depends on the values of the hidden layer h(t) and the label y(t), we have by the chain rule:

![][08]
[08]:http://latex.codecogs.com/png.latex?\frac{dL}{dw}%20=%20\sum_{t=1}^{T}%20\sum_{i=1}^{M}%20\frac{dL}{dh_i(t)}\frac{dh_i(t)}{dw}

where h_i(t) is the scalar corresponding to the i’th memory cell’s hidden output and M is the total number of memory cells. Since the network propagates information forwards in time, changing h_i(t) will have no effect on the loss prior to time t, which allows us to write:

![][09]
[09]:http://latex.codecogs.com/png.latex?\frac{dL}{dh_i(t)}%20=%20\sum_{s=1}^{T}%20\frac{dl(s)}{dh_i(t)}%20=%20\sum_{s=t}^{T}%20\frac{dl(s)}{dh_i(t)}

For notational convenience we introduce the variable L(t) that represents the cumulative loss from step t onwards:

![][10]
[10]:http://latex.codecogs.com/png.latex?L(t)%20=%20\sum_{s=t}^{s=T}%20l(s)

such that L(1) is the loss for the entire sequence. This allows us to rewrite the above equation as:

![][11]
[11]:http://latex.codecogs.com/png.latex?\frac{dL}{dh_i(t)}%20=%20\sum_{s=t}^T%20\frac{dl(s)}{dh_i(t)}%20=%20\frac{dL(t)}{dh_i(t)}

With this in mind, we can rewrite our gradient calculation as:

![][12]
[12]:http://latex.codecogs.com/png.latex?\frac{dL}{dw}%20=%20\sum_{t=1}^{T}%20\sum_{i=1}^{M}%20\frac{dL(t)}{dh_i(t)}\frac{dh_i(t)}{dw}

Make sure you understand this last equation. The computation of dh_i(t) / dw follows directly from the forward propagation equations presented earlier. We now show how to compute dL(t) / dh_i(t), which is where the so-called backpropagation through time comes into play.

Backpropagation through time (BPTT)

(Figure: backpropagation through time)

This variable L(t) allows us to express the following recursion:

![][13]
[13]:http://latex.codecogs.com/png.latex?L(t)%20=%20\begin{cases}%20l(t)%20+%20L(t+1)%20&%20\text{if%20}%20t%20%3C%20T%20\\%20l(t)%20&%20\text{if%20}%20t%20=%20T%20\end{cases}

Hence, given activation h(t) of an LSTM node at time t, we have that:

![][14]
[14]:http://latex.codecogs.com/png.latex?\frac{dL(t)}{dh(t)}%20=%20\frac{dl(t)}{dh(t)}%20+%20\frac{dL(t+1)}{dh(t)}

Now, we know where the first term on the right hand side, dl(t) / dh(t), comes from: it's simply the elementwise derivative of the loss l(t) with respect to the activations h(t) at time t. The second term, dL(t+1) / dh(t), is where the recurrent nature of LSTMs shows up. It shows that we need the next node's derivative information in order to compute the current node's derivative information. Since we will ultimately need to compute dL(t) / dh(t) for all t = 1, 2, ... , T, we start by computing

![][15]
[15]:http://latex.codecogs.com/png.latex?\frac{dL(T)}{dh(T)}%20=%20\frac{dl(T)}{dh(T)}

and work our way backwards through the network. Hence the term backpropagation through time. With these intuitions in place, we jump into the code.
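Concretely, the backward sweep over time can be sketched as follows. This outer loop is my own illustration of the structure, not code from the repository: nodes, y_seq and loss_grad are hypothetical names, and each node is assumed to expose the top_diff_is method and the bottom_diff_h / bottom_diff_s state fields that appear in the per-node code in the next section.

import numpy as np

def backprop_through_time(nodes, y_seq, loss_grad):
    # nodes[0..T-1] are the unrolled LSTM nodes of one sequence (hypothetical),
    # y_seq are the per-step targets, and loss_grad(h, y) returns dl/dh for one step.
    T = len(nodes)
    # at the last step, dL(T)/dh(T) = dl(T)/dh(T), and no cell-state gradient flows in from beyond T
    diff_h = loss_grad(nodes[T - 1].state.h, y_seq[T - 1])
    diff_s = np.zeros_like(diff_h)
    for t in reversed(range(T)):
        nodes[t].top_diff_is(diff_h, diff_s)
        if t > 0:
            # for the previous node: dL/dh = (its local loss derivative) + (derivative passed back from this node)
            diff_h = loss_grad(nodes[t - 1].state.h, y_seq[t - 1]) + nodes[t].state.bottom_diff_h
            # cell-state gradient carried back along the constant error carousel
            diff_s = nodes[t].state.bottom_diff_s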

Code (Talk is cheap, Show me the code)

We now present the code that performs the backprop pass through a single node at time 1 <= t <= T. The code takes as input:

* top_diff_h, the derivative dL(t) / dh(t) (the local loss derivative at time t plus whatever was passed back from the node at time t+1);
* top_diff_s, the derivative dL(t+1) / ds(t) carried back along the cell state.

And computes:

* bottom_diff_h, the derivative dL(t) / dh(t-1);
* bottom_diff_s, the cell-state gradient that becomes the previous node's top_diff_s;

whose values will need to be propagated backwards in time. The code also adds derivatives to:

* the parameter gradient accumulators wg_diff, wi_diff, wf_diff, wo_diff and bg_diff, bi_diff, bf_diff, bo_diff,

since recall that we must sum the derivatives from each time step:

![][16]
[16]:http://latex.codecogs.com/png.latex?\frac{dL}{dw}%20=%20\sum_{t=1}^{T}%20\sum_{i=1}^{M}%20\frac{dL(t)}{dh_i(t)}\frac{dh_i(t)}{dw}

Also, note that we differentiate with respect to the concatenated input vector (dxc in the code), where we recall that x_c(t) = [x(t), h(t-1)], and then split that derivative back into its x(t) and h(t-1) parts. Without further ado, the code:

# top_diff_is is a method of the LSTM node class used here: np is NumPy,
# self.state holds the forward-pass activations g, i, f, o and the cell state s,
# self.xc is the concatenated input [x(t), h(t-1)], and self.s_prev is s(t-1).
def top_diff_is(self, top_diff_h, top_diff_s):
    # notice that top_diff_s is carried along the constant error carousel
    ds = self.state.o * top_diff_h + top_diff_s   # dL(t)/ds(t)
    do = self.state.s * top_diff_h                # dL(t)/do(t)
    di = self.state.g * ds                        # dL(t)/di(t)
    dg = self.state.i * ds                        # dL(t)/dg(t)
    df = self.s_prev * ds                         # dL(t)/df(t)

    # diffs w.r.t. vector inside sigma / tanh function
    di_input = (1. - self.state.i) * self.state.i * di
    df_input = (1. - self.state.f) * self.state.f * df
    do_input = (1. - self.state.o) * self.state.o * do
    dg_input = (1. - self.state.g ** 2) * dg

    # diffs w.r.t. inputs
    self.param.wi_diff += np.outer(di_input, self.xc)
    self.param.wf_diff += np.outer(df_input, self.xc)
    self.param.wo_diff += np.outer(do_input, self.xc)
    self.param.wg_diff += np.outer(dg_input, self.xc)
    self.param.bi_diff += di_input
    self.param.bf_diff += df_input
    self.param.bo_diff += do_input
    self.param.bg_diff += dg_input

    # compute bottom diff
    dxc = np.zeros_like(self.xc)
    dxc += np.dot(self.param.wi.T, di_input)
    dxc += np.dot(self.param.wf.T, df_input)
    dxc += np.dot(self.param.wo.T, do_input)
    dxc += np.dot(self.param.wg.T, dg_input)

    # save bottom diffs
    self.state.bottom_diff_s = ds * self.state.f
    self.state.bottom_diff_x = dxc[:self.param.x_dim]
    self.state.bottom_diff_h = dxc[self.param.x_dim:]

Details

The forward propagation equations show that modifying s(t) affects the loss L(t) by directly changing the values of h(t) as well as h(t+1). However, modifying s(t) affects L(t+1) only by modifying h(t+1). Therefore, by the chain rule:

![][17]
[17]:http://latex.codecogs.com/png.latex?\begin{aligned}\frac{dL(t)}{ds_i(t)}%20&=%20\frac{dL(t)}{dh_i(t)}%20\frac{dh_i(t)}{ds_i(t)}%20+%20\frac{dL(t)}{dh_i(t+1)}%20\frac{dh_i(t+1)}{ds_i(t)}%20\\%20&=%20\frac{dL(t)}{dh_i(t)}%20\frac{dh_i(t)}{ds_i(t)}%20+%20\frac{dL(t+1)}{dh_i(t+1)}%20\frac{dh_i(t+1)}{ds_i(t)}%20\\%20&=%20\frac{dL(t)}{dh_i(t)}%20\frac{dh_i(t)}{ds_i(t)}%20+%20\frac{dL(t+1)}{ds_i(t)}%20\\%20&=%20\frac{dL(t)}{dh_i(t)}%20\frac{dh_i(t)}{ds_i(t)}%20+%20[\texttt{top\_diff\_s}]_i\end{aligned}

Since the forward propagation equations state:

![][18]
[18]:http://latex.codecogs.com/png.latex?h(t)%20=%20s(t)%20*%20o(t)

we get that:

![][19]
[19]:http://latex.codecogs.com/png.latex?\frac{dL(t)}{dh_i(t)}%20\frac{dh_i(t)}{ds_i(t)}%20=%20o_i(t)%20*%20[\texttt{top\_diff\_h}]_i

Putting all this together we have:

ds = self.state.o * top_diff_h + top_diff_s

The rest of the equations should be straightforward to derive, please let me know if anything is unclear.
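For reference, here is a sketch of how the remaining lines follow from s(t) = g(t) * i(t) + s(t-1) * f(t) and h(t) = s(t) * o(t), written with the corresponding code on the right (ds denotes the full dL(t)/ds(t) computed above):

    \begin{aligned}
    \frac{dL(t)}{do_i(t)} &= \frac{dL(t)}{dh_i(t)} \, s_i(t) = [\texttt{top\_diff\_h}]_i \, s_i(t)
      &&\rightarrow\ \texttt{do = self.state.s * top\_diff\_h} \\
    \frac{dL(t)}{di_i(t)} &= \frac{dL(t)}{ds_i(t)} \, g_i(t)
      &&\rightarrow\ \texttt{di = self.state.g * ds} \\
    \frac{dL(t)}{dg_i(t)} &= \frac{dL(t)}{ds_i(t)} \, i_i(t)
      &&\rightarrow\ \texttt{dg = self.state.i * ds} \\
    \frac{dL(t)}{df_i(t)} &= \frac{dL(t)}{ds_i(t)} \, s_i(t-1)
      &&\rightarrow\ \texttt{df = self.s\_prev * ds} \\
    \frac{dL(t)}{ds_i(t)} \, \frac{ds_i(t)}{ds_i(t-1)} &= \frac{dL(t)}{ds_i(t)} \, f_i(t)
      &&\rightarrow\ \texttt{bottom\_diff\_s = ds * self.state.f}
    \end{aligned}

The *_input lines then simply multiply these by the activation-function derivatives, sigma'(z) = sigma(z)(1 - sigma(z)) for the gates and tanh'(z) = 1 - tanh(z)^2 for g, and the weight, bias and dxc updates follow from W x_c(t) + b by the usual matrix calculus.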


Test LSTM Network

The code uses this hand-implemented LSTM network to fit a target sequence, y_list = [-0.5, 0.2, 0.1, -0.5]. The test output is as follows:

cur iter:  0
y_pred[0] : 0.041349
y_pred[1] : 0.069304
y_pred[2] : 0.116993
y_pred[3] : 0.165624
loss:  0.753483886253
cur iter:  1
y_pred[0] : -0.223297
y_pred[1] : -0.323066
y_pred[2] : -0.394514
y_pred[3] : -0.433984
loss:  0.599065083953
cur iter:  2
y_pred[0] : -0.140715
y_pred[1] : -0.181836
y_pred[2] : -0.219436
y_pred[3] : -0.238904
loss:  0.445095565699
cur iter:  3
y_pred[0] : -0.138010
y_pred[1] : -0.166091
y_pred[2] : -0.203394
y_pred[3] : -0.233627
loss:  0.428061605701
cur iter:  4
y_pred[0] : -0.139986
y_pred[1] : -0.157368
y_pred[2] : -0.195655
y_pred[3] : -0.237612
loss:  0.413581711096
cur iter:  5
y_pred[0] : -0.144410
y_pred[1] : -0.151859
y_pred[2] : -0.191676
y_pred[3] : -0.246137
loss:  0.399770442382
cur iter:  6
y_pred[0] : -0.150306
y_pred[1] : -0.147921
y_pred[2] : -0.189501
y_pred[3] : -0.257119
loss:  0.386136380384
cur iter:  7
y_pred[0] : -0.157119
y_pred[1] : -0.144659
y_pred[2] : -0.188067
y_pred[3] : -0.269322
loss:  0.372552465753
cur iter:  8
y_pred[0] : -0.164490
y_pred[1] : -0.141537
y_pred[2] : -0.186737
y_pred[3] : -0.281914
loss:  0.358993892096
cur iter:  9
y_pred[0] : -0.172187
y_pred[1] : -0.138216
y_pred[2] : -0.185125
y_pred[3] : -0.294326
loss:  0.345449256686
cur iter:  10
y_pred[0] : -0.180071
y_pred[1] : -0.134484
y_pred[2] : -0.183013
y_pred[3] : -0.306198
loss:  0.331888922037

……

cur iter:  97
y_pred[0] : -0.500351
y_pred[1] : 0.201185
y_pred[2] : 0.099026
y_pred[3] : -0.499154
loss:  3.1926009167e-06
cur iter:  98
y_pred[0] : -0.500342
y_pred[1] : 0.201122
y_pred[2] : 0.099075
y_pred[3] : -0.499190
loss:  2.88684626031e-06
cur iter:  99
y_pred[0] : -0.500331
y_pred[1] : 0.201063
y_pred[2] : 0.099122
y_pred[3] : -0.499226
loss:  2.61076360677e-06

As we can see, over 100 iterations the loss keeps decreasing and the predictions converge toward the target sequence y_list = [-0.5, 0.2, 0.1, -0.5].
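As a quick sanity check, using only the numbers printed above, the reported loss is just the sum over time steps of the squared prediction error, i.e. L = sum_t l(t) from the derivation:

y_list = [-0.5, 0.2, 0.1, -0.5]
y_pred = [-0.500331, 0.201063, 0.099122, -0.499226]   # printed values for cur iter 99
loss = sum((p - y) ** 2 for p, y in zip(y_pred, y_list))
print(loss)   # ~2.61e-06, consistent with the loss reported for cur iter 99 (up to rounding of the printed predictions)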

Reference


(If you liked this post, feel free to give it a like; please credit the source when reposting. Thanks!)
