TensorFlow实战(三)分类应用入门:MNIST手写数字识别

教程传送门:深度学习应用开发-TensorFlow实践 第七讲
完整笔记

一、问题描述

MNIST数据集由来自250个不同人手写的数字构成,包括55000个训练集,5000个验证集和10000个测试集。

二、读取数据集

1. 数据集解读

    #TensorFlow提供了数据集读取方法
    import tensorflow as tf
    import tensorflow.examples.tutorials.mnist.input_data as input_data
    #下载MNIST数据集到指定目录下(指定独热编码)
    mnist = input_data.read_data_sets("MNIST_data/", one_hot = True)
    
    #查看数量
    print('训练集 train 数量:', mnist.train.num_examples, 
          ',验证集 validation 数量:',  mnist.validation.num_examples, 
          ',测试集 test 数量:', mnist.test.num_examples)
    # 训练集 train 数量: 55000 ,验证集 validation 数量: 5000 ,测试集 test 数量: 10000
    
    #查看shape
    print('train images shape:', mnist.train.images.shape,
         'labels shape:',mnist.train.labels.shape)
    # train images shape: (55000, 784) labels shape: (55000, 10)   
    #图像是 28x28=784, 标签共 10 类,0-9
    
    #具体看一副image的数据
    mnist.train.images[0].shape
    # (784,)
    
    #image数据再塑形
    mnist.train.images[0].reshape(28, 28) #按行优先,逐行排列
  • 独热编码(one hot encoding)

一种稀疏向量,其中一个元素设为1,所有其他元素均为0。常用于表示拥有有限个可能值的字符串或标识符。

    #下载MNIST数据集到指定目录下(指定独热编码)
    mnist = input_data.read_data_sets("MNIST_data/", one_hot = True)
    
    mnist.train.labels[1]
    # array([0., 0., 0., 1., 0., 0., 0., 0., 0., 0.])
    #表示数字3
    
    #独热编码取便签值:使用 argmax
    import numpy as np
    np.argmax(mnist.train.labels[1]) #argmax返回的是最大数的索引
    # 3

1.使用one hot编码,将离散特征的取值拓展到了欧式空间,离散特征的某个取值就对应欧式空间的某个点。
2.机器学习算法中,特征之间距离的计算或相似度的常用计算方法是基于欧式空间的。
(如若不用独热编码,按距离计算,说1比8更相似于3就不合理。)
3.将离散型特征使用one-hot编码,会让特征之间的距离计算更加合理。

  • 非one-hot编码的标签值
    mnist_no_one_hot = input_data.read_data_sets("MNIST_data/", one_hot = False)
    print(mnist_no_one_hot.train.labels[0:10])
    # [7 3 4 6 1 8 1 0 9 8] 
    #one_hot = False,直接返回标签对应的数值

2. 可视化image

    #定义可视化image的方法
    import matplotlib.pyplot as plt
    def plot_image(image):
        plt.imshow(image.reshape(28, 28), camp = 'binary')
        plt.show()
        #可视化images[1]
    plot_image(mnist.train.images[1])
训练集中的第二张手写字

3. 数据集划分

通过将数据集划分为三个子集,可以大幅降低过拟合的发生几率:
使用验证集评估训练集的效果。
在模型“通过”验证集后,使用测试集再次检查评估结果。

4. 数据的批量读取

    #传统方法,前10条
    print(mnist.train.labels[0:10])  
    
    #内部会对数据集先做shuffle打乱数据
    batch_images_xs, batch_labels_ys = mnist.train.next_batch(batch_size = 10)
    print(batch_labels_ys)

三、分类模型构建与训练

1. 模型构建

  • 定义待输入数据的占位符
    #mnist 中每张图片共有28*28=784个像素点
    x = tf.placeholder(tf.float32, [None, 784], name = "X")
    #0-9一共10个数字 => 10个类别
    y = tf.placeholder(tf.float32, [None, 10], name = "Y")
  • 定义模型变量

在本案例中,以正态分布的随机数初始化权重W,以常数0初始化偏置b

    #定义变量
    W = tf.Variable(tf.random_normal([784, 10]), name = "W")
    b = tf.Variable(tf.zeros([10]), name = "b")
  • 了解一下tf.random_normal()

从“服从指定正态分布的序列”中随机取出指定个数的值

    norm = tf.random_normal([100]) #生成100个随机数
    with tf.Session() as sess:
        norm_data = norm.eval()
    print(norm_data[:10]) #打印前10个随机数
    # [ 0.72286093  0.1870783   0.00341146  0.41772947  0.66194445  0.08350101
     -3.12047291 -0.87533593  0.36186498  0.94298702]
     
    import matplotlib.pyplot as plt
    plt.hist(norm_data) #绘制直方图
    plt.show()
image
  • 定义前向计算
    forward = tf.matmul(x, W) + b #前向计算
image
  • 结果分类
    pred = tf.nn.softmax(forward) #Softmax分类
image

2. 逻辑回归

许多问题的预测结果是一个数值,比如房价预测问题,可以用线性模型来描述:

    Y=x_1*w_1+x_2*w_2+...+x_n*w_n+b

但也有很多场景需要输出的是概率估算值,例如:
(1)根据邮件内容判断是垃圾邮件的可能性
(2)根据医学影像判断肿瘤是恶性的可能性
(3)手写数字分别是0、1、2、3...8、9的可能性(概率)
这时需要将预测输出值控制在[0,1]区间内。
二元分类问题的目标是正确预测两个标签中的一个,逻辑回归可以用于处理这类问题。

  • Sigmod函数

Sigmod函数(S型函数)生成的输出值正好具有这些特性,定义域为全体实数,值域在[0,1]之间,Z值在0点对应的结果为0.5,且sigmod函数连续可微分。

    y=\frac{1}{1+e^{-z}}
    z=x_1*w_1+x_2*w_2+...+x_n*w_n+b
image
  • 逻辑回归中的损失函数

线性回归的损失函数是平方损失,而若逻辑回归也用平方损失,将Sigmod函数代入平方损失函数,将有多个极小值,如果采用梯度下降法,会容易导致陷入局部最优解

image

二元逻辑回归的损失函数一般采用对数损失函数,其中(x,y)∈D是有标签样本(x,y)的数据集,y是有标签样本中的标签,取值必须是0或1,y'是对于特征集x的预测值(介于0和1之间)

    J(W,b)= \sum_{(x,y)∈D}-ylog(y')-(1-y)log(1-y')
image

3. 多元分类和Softmax

  • Softmax思想

在多分类问题中,Softmax会为每个类别分配一个用小数表示的概率。这些用小数表示的概率相加之和必须为1.0.

image
  • Softmax方程式
    p_i = \frac{e^{y_i}}{\sum_{k=1}^{C}e^{y_k}}
  • 交叉熵损失函数

交叉熵是一个信息论中的概念,它原来是用来估算平均编码长度的。给定两个概率分布p和q,通过q来表示p的交叉熵为

    H(p,q) = - \sum_xp(x)logq(x)

交叉熵刻画的是两个概率分布之间的距离,p代表正确答案,q代表的是预测值,交叉熵越小,两个概率的分布约接近,损失越低。

交叉熵损失函数计算案例

  • 定义交叉熵损失函数
    yi为标签值,yi'为预测值,公式如下:
    Loss = - \sum_{i=1}^ny_ilogy_i' 
    #定义交叉熵损失函数
    loss_function = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),
                                                 reduction_indices=1)) 

4. 模型构建与训练实践

  • 设置训练参数
    train_epochs = 50 #训练轮数
    batch_size = 100 #单次训练样本数(批次大小)
    total_batch = int(mnist.train.num_examples/batch_size) #一轮训练有多少批次
    display_step = 1 #显示粒度
    learning_rate = 0.01 #学习率
  • 选择优化器
    #梯度下降优化器
    optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss_function)
  • argmax详解
    import tensorflow as tf
    import numpy as np
    arr1 = np.array([1,3,2,5,7,0])
    arr2 = np.array([[1.0,2,3],[3,2,1],[4,7,2],[8,3,2]])
    print("arr1=", arr1)
    print("arr2=\n", arr2)
    # arr1= [1 3 2 5 7 0]
    # arr2=
    # [[ 1.  2.  3.]
    # [ 3.  2.  1.]
    # [ 4.  7.  2.]
    # [ 8.  3.  2.]]
    
    #返回最大值的下标
    argmax_1 = tf.argmax(arr1)
    #指定第二个参数为0,按第一维(行)的元素取值,即同列的每一行最大值的下标
    argmax_20 = tf.argmax(arr2, 0) #指定第二个参数为1,按第二维(列)的元素取值,即同行的每一列最大值的下标
    argmax_21 = tf.argmax(arr2, 1)
    #指定第二个参数为-1,则第最后维的元素取值
    argmax_22 = tf.argmax(arr2, -1) 
    with tf.Session() as sess:
        print(argmax_1.eval())
        print(argmax_20.eval())
        print(argmax_21.eval())
        print(argmax_22.eval())
    # 4
    # [3 2 0]
    # [2 0 1 0]
    # [2 0 1 0]
  • 定义准确率
    #检查预测类别tf.argmax(pred,1)与实际类别tf.argmax(y,1)的匹配情况,相等为1,不等为0,实际要转浮点数
    correct_prediction = tf.equal(tf.argmax(pred, 1),tf.argmax(y, 1))
    #准确率,将布尔值转为浮点数,并计算平均值
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  • 声明会话,初始化变量
    sess = tf.Session() #声明会话
    init = tf.global_variables_initializer() #变量初始化
    sess.run(init)
  • 训练模型
    #开始训练
    for epoch in range(train_epochs):
        for batch in range(total_batch):
            xs, ys = mnist.train.next_batch(batch_size) #读取批次数据
            sess.run(optimizer, feed_dict = {x: xs, y: ys}) #执行批次训练
        #total_batch个批次训练完成后,用验证数据计算误差与准确率,验证集没有分批
        loss,acc = sess.run([loss_function,accuracy],
                           feed_dict= {x: mnist.validation.images, y: mnist.validation.labels})
        #打印训练过程中的详细信息
        if(epoch+1)%display_step == 0:
            print("Train Epoch:",'%02d'%(epoch+1),"Loss=","{:.9f}".format(loss),\
                 "Accuracy=","{:.4f}".format(acc))
    print("Train Finished!")
    
    # Train Epoch: 01 Loss= 5.441855431 Accuracy= 0.3148
    # ...(Loss趋于更小,准确率越来越高,其余轮次略)
    # Train Epoch: 50 Loss= 0.655943811 Accuracy= 0.8618
    # Train Finished!
  • 评估模型

完成训练后,在10000条测试集上评估模型的准确率

    accu_test = sess.run(accuracy,
                        feed_dict = {x: mnist.test.images, y: mnist.test.labels})
    print("Test Accuracy:",accu_test)
    # Test Accuracy: 0.8619

5. 模型应用与可视化

  • 应用模型
    #由于pred预测结果是one-hot编码格式,所以需要转换为0-9数字
    prediction_result = sess.run(tf.argmax(pred,1),feed_dict={x: mnist.test.images})
    #查看预测结果中的前10项
    prediction_result[0:10]
    
    # array([7, 6, 1, 0, 4, 1, 4, 9, 6, 9], dtype=int64)
  • 定义可视化函数
    #定义可视化函数
    import matplotlib.pyplot as plt
    import numpy as np
    def plot_images_labels_prediction(images, #图像列表
                                     labels, #标签列表
                                     prediction, #预测值列表
                                     index, #从第index个开始显示
                                     num = 10): #缺省一次显示10幅
        fig = plt.gcf() #获取当前图表,Get Current Figure
        fig.set_size_inches(10, 12) #宽10英寸高12英寸,1英寸等于2.54cm
        if num > 25:
            num = 25 #最多显示25个子图
        for i in range(0, num):
            ax = plt.subplot(5,5, i+1) #获取当前要处理的子图,i+1 从第1个子图开始到num+1个子图
            ax.imshow(np.reshape(imges[index],(28, 28)),
                     cmap = 'binary') #显示第index个图像
            title = "label=" + str(np.argmax(labels[index])) #构建该图上要显示的title
            if len(prediction)>0:
                title += ",predict=" + str(prediction[index])
                ax.set_title(title,fontsize = 10) #显示图上的title信息
                ax.set_xticks([]); #不显示坐标轴
                ax.set_yticks([])
                index += 1
        plt.show()
  • 可视化预测结果
    plot_images_labels_prediction(mnist.test.images,
                                 mnist.test.labels,
                                 prediction_result,10,10)
预测结果
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 162,475评论 4 372
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 68,744评论 2 307
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 112,101评论 0 254
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 44,732评论 0 221
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 53,141评论 3 297
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 41,049评论 1 226
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 32,188评论 2 320
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 30,965评论 0 213
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 34,716评论 1 250
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 30,867评论 2 254
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 32,341评论 1 265
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 28,663评论 3 263
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 33,376评论 3 244
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 26,200评论 0 8
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 26,990评论 0 201
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 36,179评论 2 285
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 35,979评论 2 279