RNN(循环神经网络)训练手写数字

简介


RNN(recurrent neural network )循环(递归)神经网络主要用来处理序列数据。因为传统的神经网络从输入-隐含层-输出是全连接的,层中的神经元是没有连接的,所以对于输入数据本身具有时序性(例如输入的文本数据,每个单词之间有一定联系)的处理表现并不理想。而RNN每一个输出与前面的输出建立起关联,这样就能够很好的处理序列化的数据。
单纯循环神经网络也面临一些问题,如无法处理随着递归,权重指数级爆炸或消失的问题,难以捕捉长期时间关联。这些可以结合不同的LSTM很好的解决这个问题。
本文主要介绍简单的RNN用OC的实现,并通过训练MNIST数据来检测模型。后面有时间再介绍LSTM的实现。

公式


简单的RNN就三层,输入-隐含层-输出,如下:

将其展开的模型如下:

其中,A这个隐含层的操作就是将当前输入与前面的输出相结合,然后激活就得到当前状态信号。如下:

计算公式如下:

其中Xt是输入数据序列,St是的状态序列,V*St就是图中Ot输出,softmax运算并没有画出来。

由于RNN结构简单,反向传播的公式结合一点数理知识就可以求得,这里就不列出,详见代码实现。

数据处理


由于没找到比较好的训练数据,这里用的是前面《OC实现Softmax识别手写数字》文章里面的MNIST数据源。输入数据处理、softmax实现也都是复用的。
图片数据本质上并非是序列化的,我这里将图片的每行的的像素数据当作一个信号输入,如果一共N行,序列长度就是N。训练数据是28*28维的图片,那么就是每个信号是28*1,一共时间长度是28。

RNN实现


简单的RNN实现流程并不复杂,需要训练的参数就5个:输入的权值、神经元间转移的权值、输出的权值、以及两个转移和输出的偏置量。直接看代码:

//
//  MLRnn.m
//  LSTM
//
//  Created by Jiao Liu on 11/9/16.
//  Copyright © 2016 ChangHong. All rights reserved.
//

#import "MLRnn.h"

@implementation MLRnn

#pragma mark - Inner Method

+ (double)truncated_normal:(double)mean dev:(double)stddev
{
    double outP = 0.0;
    do {
        static int hasSpare = 0;
        static double spare;
        if (hasSpare) {
            hasSpare = 0;
            outP = mean + stddev * spare;
            continue;
        }
        
        hasSpare = 1;
        static double u,v,s;
        do {
            u = (rand() / ((double) RAND_MAX)) * 2.0 - 1.0;
            v = (rand() / ((double) RAND_MAX)) * 2.0 - 1.0;
            s = u * u + v * v;
        } while ((s >= 1.0) || (s == 0.0));
        s = sqrt(-2.0 * log(s) / s);
        spare = v * s;
        outP = mean + stddev * u * s;
    } while (fabsl(outP) > 2*stddev);
    return outP;
}

+ (double *)fillVector:(double)num size:(int)size
{
    double *outP = malloc(sizeof(double) * size);
    vDSP_vfillD(&num, outP, 1, size);
    return outP;
    
}

+ (double *)weight_init:(int)size
{
    double *outP = malloc(sizeof(double) * size);
    for (int i = 0; i < size; i++) {
        outP[i] = [MLRnn truncated_normal:0 dev:0.1];
    }
    return outP;
}

+ (double *)bias_init:(int)size
{
    return [MLRnn fillVector:0.1f size:size];
}

+ (double *)tanh:(double *)input size:(int)size
{
    for (int i = 0; i < size; i++) {
        double num = input[i];
        if (num > 20) {
            input[i] = 1;
        }
        else if (num < -20)
        {
            input[i] = -1;
        }
        else
        {
            input[i] = (exp(num) - exp(-num)) / (exp(num) + exp(-num));
        }
    }
    return input;
}

#pragma mark - Init

- (id)initWithNodeNum:(int)num layerSize:(int)size dataDim:(int)dim
{
    self = [super init];
    if (self) {
        _nodeNum = num;
        _layerSize = size;
        _dataDim = dim;
        [self setupNet];
    }
    return self;
}

- (id)init
{
    self = [super init];
    if (self) {
        [self setupNet];
    }
    return self;
}

- (void)setupNet
{
    _inWeight = [MLRnn weight_init:_nodeNum * _dataDim];
    _outWeight = [MLRnn weight_init:_nodeNum * _dataDim];
    _flowWeight = [MLRnn weight_init:_nodeNum * _nodeNum];
    _outBias = calloc(_dataDim, sizeof(double));
    _flowBias = calloc(_nodeNum, sizeof(double));
    _output = calloc(_layerSize * _dataDim, sizeof(double));
    _state = calloc(_layerSize * _nodeNum, sizeof(double));
}

#pragma mark - Main Method

- (double *)forwardPropagation:(double *)input
{
    _input = input;
    // clean data
    double zero = 0;
    vDSP_vfillD(&zero, _output, 1, _layerSize * _dataDim);
    vDSP_vfillD(&zero, _state, 1, _layerSize * _nodeNum);
    
    for (int i = 0; i < _layerSize; i++) {
        double *temp1 = calloc(_nodeNum, sizeof(double));
        double *temp2 = calloc(_nodeNum, sizeof(double));
        if (i == 0) {
            vDSP_mmulD(_inWeight, 1, (input + i * _dataDim), 1, temp1, 1, _nodeNum, 1, _dataDim);
            vDSP_vaddD(temp1, 1,_flowBias, 1, temp1, 1, _nodeNum);
        }
        else
        {
            vDSP_mmulD(_inWeight, 1, (input + i * _dataDim), 1, temp1, 1, _nodeNum, 1, _dataDim);
            vDSP_mmulD(_flowWeight, 1, (_state + (i-1) * _nodeNum), 1, temp2, 1, _nodeNum, 1, _nodeNum);
            vDSP_vaddD(temp1, 1, temp2, 1, temp1, 1, _nodeNum);
            vDSP_vaddD(temp1, 1,_flowBias, 1, temp1, 1, _nodeNum);
        }
        [MLRnn tanh:temp1 size:_nodeNum];
        vDSP_vaddD((_state + i * _nodeNum), 1, temp1, 1, (_state + i * _nodeNum), 1, _nodeNum);
        vDSP_mmulD(_outWeight, 1, temp1, 1, (_output + i * _dataDim), 1, _dataDim, 1, _nodeNum);
        vDSP_vaddD((_output + i * _dataDim), 1, _outBias, 1,  (_output + i * _dataDim), 1, _dataDim);
        
        free(temp1);
        free(temp2);
    }
    
    return _output;
}

- (void)backPropagation:(double *)loss
{
    double *flowLoss = calloc(_nodeNum, sizeof(double));
    for (int i = _layerSize - 1; i >= 0 ; i--) {
        vDSP_vaddD(_outBias, 1, (loss + i * _dataDim), 1, _outBias, 1, _dataDim);
        double *transWeight = calloc(_nodeNum * _dataDim, sizeof(double));
        vDSP_mtransD(_outWeight, 1, transWeight, 1, _nodeNum, _dataDim);
        double *tanhLoss = calloc(_nodeNum, sizeof(double));
        vDSP_mmulD(transWeight, 1, (loss + i * _dataDim), 1, tanhLoss, 1, _nodeNum, 1, _dataDim);
        double *outWeightLoss = calloc(_nodeNum * _dataDim, sizeof(double));
        vDSP_mmulD((loss + i * _dataDim), 1, (_state + i * _nodeNum), 1, outWeightLoss, 1, _dataDim, _nodeNum, 1);
        vDSP_vaddD(_outWeight, 1, outWeightLoss, 1, _outWeight, 1, _nodeNum * _dataDim);
        
        double *tanhIn = calloc(_nodeNum, sizeof(double));
        vDSP_vsqD((_state + i * _nodeNum), 1, tanhIn, 1, _nodeNum);
        double *one = [MLRnn fillVector:1 size:_nodeNum];
        vDSP_vsubD(tanhIn, 1, one, 1, tanhIn, 1, _nodeNum);
        if (i != _layerSize - 1) {
            vDSP_vaddD(tanhLoss, 1, flowLoss, 1, tanhLoss, 1, _nodeNum);
        }
        vDSP_vmulD(tanhLoss, 1, tanhIn, 1, tanhLoss, 1, _nodeNum);
        
        vDSP_vaddD(_flowBias, 1, tanhLoss, 1, _flowBias, 1, _nodeNum);
        if (i != 0) {
            double *transFlow = calloc(_nodeNum * _nodeNum, sizeof(double));
            vDSP_mtransD(_flowWeight, 1, transFlow, 1, _nodeNum, _nodeNum);
            vDSP_mmulD(transFlow, 1, tanhLoss, 1, flowLoss, 1, _nodeNum, 1, _nodeNum);
            free(transFlow);
            double *flowWeightLoss = calloc(_nodeNum * _nodeNum, sizeof(double));
            vDSP_mmulD(tanhLoss, 1, (_state + (i-1) * _nodeNum), 1, flowWeightLoss, 1, _nodeNum, _nodeNum, 1);
            vDSP_vaddD(_flowWeight, 1, flowWeightLoss, 1, _flowWeight, 1, _nodeNum * _nodeNum);
            free(flowWeightLoss);
        }

        double *inWeightLoss = calloc(_nodeNum * _dataDim, sizeof(double));
        vDSP_mmulD(tanhLoss, 1, (_input + i * _dataDim), 1, inWeightLoss, 1, _nodeNum, _dataDim, 1);
        vDSP_vaddD(_inWeight, 1, inWeightLoss, 1, _inWeight, 1, _nodeNum * _dataDim);
        
        free(transWeight);
        free(tanhLoss);
        free(outWeightLoss);
        free(tanhIn);
        free(one);
        free(inWeightLoss);
    }
    free(flowLoss);
    free(loss);
}

@end

很多初始化方法以及内部函数直接是复用《OC实现(CNN)卷积神经网络》中相关的方法。

结语


我这里使用RNN,迭代2500次,每次训练100张图片,单个神经元节点个数选择50,得到的正确率94%左右。

有兴趣的朋友可以点这里看完整代码

本文参考:

  1. Understanding LSTM Networks
  2. recurrent-neural-networks-tutorial
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 160,277评论 4 364
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 67,777评论 1 298
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 109,946评论 0 245
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 44,271评论 0 213
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 52,636评论 3 288
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 40,767评论 1 221
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 31,989评论 2 315
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 30,733评论 0 204
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 34,457评论 1 246
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 30,674评论 2 249
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 32,155评论 1 261
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 28,518评论 3 258
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 33,160评论 3 238
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 26,114评论 0 8
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 26,898评论 0 198
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 35,822评论 2 280
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 35,705评论 2 273

推荐阅读更多精彩内容