TensorFlow从0到1 - 16 - L2正则化对抗“过拟合”

TensorFlow从0到1系列回顾

前面的14 交叉熵损失函数——防止学习缓慢15 重新思考神经网络初始化从学习缓慢问题入手,尝试改进神经网络的学习。本篇讨论过拟合问题,并引入与之相对的L2正则化(Regularization)方法。

overfitting,来源:http://blog.algotrading101.com

无处不在的过拟合

模型对于已知数据的描述适应性过高,导致对新数据的泛化能力不佳,我们称模型对于数据过拟合(overfitting)

过拟合无处不在。

罗素的火鸡对自己的末日始料未及,曾真理般存在的牛顿力学沦为狭义相对论在低速情况下的近似,次贷危机破灭了美国买房只涨不跌的神话,血战钢锯岭的医疗兵Desmond也并不是懦夫。

凡是基于经验的学习,都存在过拟合的风险。动物、人、机器都不能幸免。

Russell's Turkey,来源:http://chaospet.com/115-russells-turkey/g

谁存在过拟合?

对于一些离散的二维空间中的样本点,下面两条曲线谁存在过拟合?

谁存在过拟合?

遵循奥卡姆剃刀的一派,主张“如无必要,勿增实体”。他们相信相对简单的模型泛化能力更好:上图中的蓝色直线,虽然只有很少的样本点直接落在它上面,但是不妨认为这些样本点或多或少包含一些噪声。基于这种认知,可以预测新样本也会在这条直线附近出现。

或许很多时候,倾向简单会占上风,但是真实世界的复杂性深不可测。虽然在自然科学中,奥卡姆剃刀被作为启发性技巧来使用,帮助科学家发展理论模型工具,但是它并没有被当做逻辑上不可辩驳的定理或者科学结论。总有简单模型表达不了,只能通过复杂模型来描述的事物存在。很有可能红色的曲线才是对客观世界的真实反映。

康德为了对抗奥卡姆剃刀产生的影响,创建了他自己的反剃刀:“存在的多样性不应被粗暴地忽视”。

阿尔伯特·爱因斯坦告诫:“科学理论应该尽可能简单,但不能过于简单。”

所以仅从上图来判断,一个理性的回答是:不知道。即使是如此简单的二维空间情况下,在没有更多的新样本数据做出验证之前,不能仅通过模型形式的简单或复杂来判定谁存在过拟合。

过拟合的判断

二维、三维的模型,本身可以很容易的绘制出来,当新的样本出现后,通过观察即可大致判断模型是否存在过拟合。

然而现实情况要复杂的多。对MNIST数字识别所采用的3层感知器——输入层784个神经元,隐藏层30个神经元,输出层10个神经元,包含23860个参数(23860 = 784 x 30 + 30 x 10 + 30 + 10),靠绘制模型来观察是不现实的。

最有效的方式是通过识别精度判断模型是否存在过拟合:比较模型对验证集和训练集的识别精度,如果验证集识别精度大幅低于训练集,则可以判断模型存在过拟合

至于为什么是验证集而不是测试集,请复习11 74行Python实现手写体数字识别中“验证集与超参数”一节。

然而静态的比较已训练模型对两个集合的识别精度无法回答一个问题:过拟合是什么时候发生的

要获得这个信息,就需要在模型训练过程中动态的监测每次迭代(Epoch)后训练集和验证集的识别精度,一旦出现训练集识别率继续上升而验证集识别率不再提高,就说明过拟合发生了

这种方法还会带来一个额外的收获:确定作为超参数之一的迭代数(Epoch Number)的量级。更进一步,甚至可以不设置固定的迭代次数,以过拟合为信号,一旦发生就提前停止(early stopping)训练,避免后续无效的迭代。

过拟合监测

了解了过拟合的概念以及监测方法,就可以开始分析我们训练MNIST数字识别模型是否存在过拟合了。

所用代码:tf_16_mnist_loss_weight.py。它在12 TensorFlow构建3层NN玩转MNIST代码的基础上,使用了交叉熵损失,以及1/sqrt(nin)权重初始化:

  • 1个隐藏层,包含30个神经元;
  • 学习率:3.0;
  • 迭代数:30次;
  • mini batch:10;

训练过程中,分别对训练集和验证集的识别精度进行了跟踪,如下图所示,其中红线代表训练集识别率,蓝线代表测试集识别率。图中显示,大约在第15次迭代前后,测试集的识别精度稳定在95.5%不再提高,而训练集的识别精度仍然继续上升,直到30次迭代全部结束后达到了98.5%,两者相差3%。

由此可见,模型存在明显的过拟合的特征。

训练集和验证集识别精度(基于TensorBoard绘制)

过拟合的对策:L2正则化

对抗过拟合最有效的方法就是增加训练数据的完备性,但它昂贵且有限。另一种思路是减小网络的规模,但它可能会因为限制了模型的表达潜力而导致识别精度整体下降。

本篇引入L2正则化(Regularization),可以在原有的训练数据,以及网络架构不缩减的情况下,有效避免过拟合。L2正则化即在损失函数C的表达式上追加L2正则化项

L2正则化

上式中的C0代表原损失函数,可以替换成均方误差、交叉熵等任何一种损失函数表达式。

关于L2正则化项的几点说明:

  • 求和∑是对网络中的所有权重进行的;
  • λ(lambda)为自定义参数(超参数);
  • n是训练样本的数量(注意不是所有权重的数量!);
  • L2正则化并没有偏置参与;

该如何理解正则化呢?

对于使网络达到最小损失的权重w,很可能有非常多不同分布的解:有的均值偏大、有的偏小,有的分布均匀,有的稀疏。那么在这个w的解空间里,该如何挑选相对更好的呢?正则化通过添加约束的方式,帮我们找到一个方向。

L2正则化表达式暗示着一种倾向:训练尽可能的小的权重,较大的权重需要保证能显著降低原有损失C0才能保留

至于正则化为何能有效的缓解过拟合,这方面数学解释其实并不充分,更多是基于经验的认知。

L2正则化的实现

因为在原有损失函数中追加了L2正则化项,那么是不是得修改现有反向传播算法(BP1中有用到C的表达式)?答案是不需要。

C对w求偏导数,可以拆分成原有C0对w求偏导,以及L2正则项对w求偏导。前者继续利用原有的反向传播计算方法,而后者可以直接计算得到:

C对于偏置b求偏导保持不变:

基于上述,就可以得到权重w和偏置b的更新方法:

TensorFlow实现L2正则化

TensorFlow的最优化方法tf.train.GradientDescentOptimizer包办了梯度下降、反向传播,所以基于TensorFlow实现L2正则化,并不能按照上节的算法直接干预权重的更新,而要使用TensorFlow方式:

tf.add_to_collection(tf.GraphKeys.WEIGHTS, W_2)
tf.add_to_collection(tf.GraphKeys.WEIGHTS, W_3)
regularizer = tf.contrib.layers.l2_regularizer(scale=5.0/50000)
reg_term = tf.contrib.layers.apply_regularization(regularizer)

loss = (tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(labels=y_, logits=z_3)) +
    reg_term)

对上述代码的一些说明:

  • 将网络中所有层中的权重,依次通过tf.add_to_collectio加入到tf.GraphKeys.WEIGHTS中;
  • 调用tf.contrib.layers.l2_regularizer生成L2正则化方法,注意所传参数scale=λ/n(n为训练样本的数量);
  • 调用tf.contrib.layers.apply_regularization来生成损失函数的L2正则化项reg_term,所传第一个参数为上面生成的正则化方法,第二个参数为none时默认值为tf.GraphKeys.WEIGHTS
  • 最后将L2正则化reg_term项追加到损失函数表达式;

向原有损失函数追加L2正则化项,模型和训练设置略作调整:

  • 1个隐藏层,包含100个神经元;
  • 学习率:0.5;
  • 迭代数:30次;
  • mini batch:10;

重新运行训练,跟踪训练集和验证集的识别精度,如下图所示。图中显示,在整个30次迭代中,训练集和验证集的识别率均持续上升(都超过95%),最终两者的差距控制在0.5%,过拟合程度显著的减轻了。

需要注意的是,尽管正则化有效降低了验证集上过拟合程度,但是也降低了训练集的识别精度。所以在实现L2正则化时增加了隐藏层的神经元数量(从30到100)来抵消识别精度的下降。

L2正则化(基于TensorBoard绘制)

附完整代码

import argparse
import sys
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf

FLAGS = None


def main(_):
    # Import data
    mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True,
                                      validation_size=10000)

    # Create the model
    x = tf.placeholder(tf.float32, [None, 784])
    W_2 = tf.Variable(tf.random_normal([784, 100]) / tf.sqrt(784.0))
    '''W_2 = tf.get_variable(
        name="W_2",
        regularizer=regularizer,
        initializer=tf.random_normal([784, 30], stddev=1 / tf.sqrt(784.0)))'''
    b_2 = tf.Variable(tf.random_normal([100]))
    z_2 = tf.matmul(x, W_2) + b_2
    a_2 = tf.sigmoid(z_2)

    W_3 = tf.Variable(tf.random_normal([100, 10]) / tf.sqrt(100.0))
    '''W_3 = tf.get_variable(
        name="W_3",
        regularizer=regularizer,
        initializer=tf.random_normal([30, 10], stddev=1 / tf.sqrt(30.0)))'''
    b_3 = tf.Variable(tf.random_normal([10]))
    z_3 = tf.matmul(a_2, W_3) + b_3
    a_3 = tf.sigmoid(z_3)

    # Define loss and optimizer
    y_ = tf.placeholder(tf.float32, [None, 10])

    tf.add_to_collection(tf.GraphKeys.WEIGHTS, W_2)
    tf.add_to_collection(tf.GraphKeys.WEIGHTS, W_3)
    regularizer = tf.contrib.layers.l2_regularizer(scale=5.0 / 50000)
    reg_term = tf.contrib.layers.apply_regularization(regularizer)

    loss = (tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=y_, logits=z_3)) +
        reg_term)

    train_step = tf.train.GradientDescentOptimizer(0.5).minimize(loss)

    sess = tf.InteractiveSession()
    tf.global_variables_initializer().run()

    correct_prediction = tf.equal(tf.argmax(a_3, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    scalar_accuracy = tf.summary.scalar('accuracy', accuracy)
    train_writer = tf.summary.FileWriter(
        'MNIST/logs/tf16_reg/train', sess.graph)
    validation_writer = tf.summary.FileWriter(
        'MNIST/logs/tf16_reg/validation')

    # Train
    best = 0
    for epoch in range(30):
        for _ in range(5000):
            batch_xs, batch_ys = mnist.train.next_batch(10)
            sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
        # Test trained model
        accuracy_currut_train = sess.run(accuracy,
                                         feed_dict={x: mnist.train.images,
                                                    y_: mnist.train.labels})

        accuracy_currut_validation = sess.run(
            accuracy,
            feed_dict={x: mnist.validation.images,
                       y_: mnist.validation.labels})

        sum_accuracy_train = sess.run(
            scalar_accuracy,
            feed_dict={x: mnist.train.images,
                       y_: mnist.train.labels})

        sum_accuracy_validation = sess.run(
            scalar_accuracy,
            feed_dict={x: mnist.validation.images,
                       y_: mnist.validation.labels})

        train_writer.add_summary(sum_accuracy_train, epoch)
        validation_writer.add_summary(sum_accuracy_validation, epoch)

        print("Epoch %s: train: %s validation: %s"
              % (epoch, accuracy_currut_train, accuracy_currut_validation))
        best = (best, accuracy_currut_validation)[
            best <= accuracy_currut_validation]

    # Test trained model
    print("best: %s" % best)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str, default='../MNIST/',
                        help='Directory for storing input data')
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

下载 tf_16_mnist_loss_weight_reg.py

上一篇 15 重新思考神经网络初始化
下一篇 17 Step By Step上手TensorBoard


共享协议:署名-非商业性使用-禁止演绎(CC BY-NC-ND 3.0 CN)
转载请注明:作者黑猿大叔(简书)

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 158,233评论 4 360
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 67,013评论 1 291
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 108,030评论 0 241
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 43,827评论 0 204
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 52,221评论 3 286
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 40,542评论 1 216
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 31,814评论 2 312
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 30,513评论 0 198
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 34,225评论 1 241
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 30,497评论 2 244
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 31,998评论 1 258
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 28,342评论 2 253
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 32,986评论 3 235
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 26,055评论 0 8
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 26,812评论 0 194
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 35,560评论 2 271
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 35,461评论 2 266

推荐阅读更多精彩内容