BPTT推导以及基于LSTM的MNIST手写数字识别实验

BPTT (Back PropagationThough Time)公式推导

   由于RNN被广泛应用于序列标注问题(SequenceLabeling),所以这里选取该

问题作为实例来解释BPTT。下图是典型的RNN结构展开之后的结构,非常常见。

(图一)

  则将图一具体化为下图:

(图二)

基于LSTM 进行MNIST手写数字识别实验笔记

那么将基础结构构造成时序结构如下所示:

  注意,上图显示的并不是不同Block中的不同神经元,而是同一个Block中同一个神经元

在不同时刻的状态以及不同时刻之间如何传递信息。具体Block中的细节以及公式如下图所示:

代码如下:

from __future__ import print_function

import numpy as np

import tensorflow as tf

from tensorflow.contrib import rnn

from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets

classmnistExp:

   def__init__(self,datapath,logpath):

       self.mnist = read_data_sets(datapath, one_hot=True)

       self.logpath = logpath

       self.learning_rate = 0.001

       self.iters_num = 100000

       self.batch_size = 128

       self.print_step = 10

       self.input_size = 28

       self.timesteps = 28

       self.hidden_size = 128

       self.class_num = 10

       self.epoch_num = 30

       self.test_len = 300

   defInitGblVar(self):

       with tf.name_scope('Inputs'):

           self.x = tf.placeholder("float", [None, self.timesteps, self.input_size], name='X')

           self.istate = tf.placeholder("float", [None, 2 * self.hidden_size], name='istate')

           self.y = tf.placeholder("float", [None, self.class_num], name='Y')

       with tf.name_scope('Weights'):

           self.weights = {

               'hidden': tf.Variable(tf.random_normal([self.input_size, self.hidden_size])),

               'out': tf.Variable(tf.random_normal([self.hidden_size, self.class_num]))

           }

           tf.summary.histogram('weights_hidden',self.weights['hidden'])

           tf.summary.histogram('weights_out',self.weights['out'])

       with tf.name_scope('Biases'):

           self.biases = {

               'hidden': tf.Variable(tf.random_normal([self.hidden_size])),

               'out': tf.Variable(tf.random_normal([self.class_num]))

           }

           tf.summary.histogram('bias_hidden',self.biases['hidden'])

           tf.summary.histogram('bias_out',self.biases['out'])

       self.pred = self.LSTM(self.x, self.istate, self.weights, self.biases)

       with tf.name_scope('Cost'):

           self.cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=self.pred, labels=self.y))

           tf.summary.scalar('Cost',self.cost)

       with tf.name_scope('Train'):

           self.optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate).minimize(self.cost)

       with tf.name_scope('Prediction'):

           self.correct_pred = tf.equal(tf.argmax(self.pred, 1), tf.argmax(self.y, 1))

           self.accuracy = tf.reduce_mean(tf.cast(self.correct_pred, tf.float32))

           tf.summary.scalar('Accuracy',self.accuracy)

       self.init = tf.global_variables_initializer()

   defLSTM(self,_batch_x, _istate, _weights, _biases):

       with tf.name_scope('lstm_block'):

           _batch_x = tf.transpose(_batch_x, [1, 0, 2])

           _batch_x = tf.reshape(_batch_x, [-1, self.input_size])

           _batch_x = tf.matmul(_batch_x, _weights['hidden']) + _biases['hidden']

           _batch_x = tf.split(_batch_x, self.timesteps, 0)

           lstm_block = rnn.BasicLSTMCell(self.hidden_size, forget_bias=1.0)

           outputs, states = rnn.static_rnn(lstm_block, _batch_x, dtype=tf.float32)

       return tf.matmul(outputs[-1], _weights['out']) + _biases['out']

   defrun(self):

       self.InitGblVar()

       with tf.Session() as sess:

           sess.run(self.init)

           merged = tf.summary.merge_all()

           writer = tf.summary.FileWriter(self.logpath, sess.graph)

           for i in range(self.epoch_num):

               step = 1

               while step * self.batch_size < self.iters_num:

                   batch_xs, batch_ys = self.mnist.train.next_batch(self.batch_size)

                   # tf.summary.image('batch_xs',batch_xs,max_outputs=10)

                   batch_xs = batch_xs.reshape((self.batch_size, self.timesteps, self.input_size))

                   sess.run(self.optimizer, feed_dict={self.x: batch_xs, self.y: batch_ys,

                                                  self.istate: np.zeros((self.batch_size, 2 * self.hidden_size))})

                   if step % self.print_step == 0:

                       acc, loss, summary = sess.run([self.accuracy, self.cost, merged], feed_dict={self.x: batch_xs, self.y: batch_ys,

                                                           self.istate: np.zeros((self.batch_size, 2 * self.hidden_size))})

                   step += 1

               test_data = self.mnist.test.images[:self.test_len].reshape((-1, self.timesteps, self.input_size))

               test_label = self.mnist.test.labels[:self.test_len]

               print(" [*] Epoch " + str(i+1) + ": Optimization has finished, Testing Accuracy is " ,\

                   sess.run(self.accuracy, feed_dict={self.x: test_data, self.y: test_label,self.istate: \

                       np.zeros((self.test_len, 2 * self.hidden_size))}))

               summary = sess.run(merged, feed_dict={self.x: test_data, self.y: test_label})

               writer.add_summary(summary, i)

if __name__ == '__main__':

   datapath = "C:\\Users\\Administrator\\Desktop\\deep_lab\\mnist_data"

   logpath = "/tensorboard_log/tf_Ex" #tensorboard

   obj = mnistExp(datapath,logpath)

   obj.run()

实验结果:

  这里稍微介绍一下tensorboard,以方便直观学习。在训练的时候,会自动在log文

件夹中生成一个类似这样的文件后,

   不用等训练结束也可以执行下面的语句,来观察训练情况:

 训练结果:

    SCALARS:

   GRAPHS:

   DISTRIBUTIONS:

    HISTOGRAMS:

参考:

    https://www.cnblogs.com/steven-yang/p/6407445.html

    http://www.cnblogs.com/wacc/p/5341670.html

    https://en.wikipedia.org/wiki/Matrix_calculus

    https://www.cnblogs.com/zhbzz2007/p/6339346.html

    http://zhwhong.ml/2017/02/24/Backpropagation-principle/

    https://zhuanlan.zhihu.com/p/26892413

    http://www.sohu.com/a/195366563_465975

    http://colah.github.io/posts/2015-08-Understanding-LSTMs/

    http://blog.csdn.net/u010754290/article/details/47167979

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 158,560评论 4 361
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 67,104评论 1 291
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 108,297评论 0 243
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 43,869评论 0 204
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 52,275评论 3 287
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 40,563评论 1 216
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 31,833评论 2 312
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 30,543评论 0 197
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 34,245评论 1 241
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 30,512评论 2 244
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 32,011评论 1 258
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 28,359评论 2 253
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 33,006评论 3 235
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 26,062评论 0 8
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 26,825评论 0 194
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 35,590评论 2 273
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 35,501评论 2 268

推荐阅读更多精彩内容

  • rljs by sennchi Timeline of History Part One The Cognitiv...
    sennchi阅读 7,102评论 0 10
  • 大兵那天主动找我喝酒,看到他时眼睛肿的像咸蛋超人一样,估计昨晚又一个人躲在屋子里哭了,这是他和游訫分手的第三天,谁...
    好咧阅读 225评论 0 1
  • 本周的重点是名人名言的积累和背诵。 在我们谈话过程中是不是给出一句金句或者明人明言是能够增加说服力和感染力度的,我...
    A友仔谢杰全阅读 222评论 0 0