用tensorflow.js实现浏览器内的手写数字识别

简介

Tensorflow.js是google推出的一个开源的基于JavaScript的机器学习库,相对与基于其他语言的tersorflow库,它的最特别之处就是允许我们直接把模型的训练和数据预测放在前端,置于浏览器内。

本文会用一个简单的demo介绍如何从零开始训练一个tensorflow模型,并在浏览器内实现手写数字识别,最终效果大约如下:


手写数字识别示例

本文会假设你有基本的python和JavaScript的知识。项目的完整代码参考github

准备

项目代码的目录结构如下:


项目目录结构

整个结构大概分成server和web两个部分,分别是服务端和浏览器端的代码。

我们的流程大概如下:

  1. 下载训练数据集,用python的tensorflow训练模型,并保存模型文件。
  2. 使用python的flask启动服务,使模型文件可以作为本地服务的静态文件被访问。
  3. 在网页html内,用canvas创建一个可以随意涂抹的画布,并能够获取画布上的像素信息。
  4. 在JavaScript脚本内导入tf.js,载入训练模型,通过模型计算画布上的信息的预测结果,并显示在图表上。

我们需要的所有依赖如下:

python:

建议使用3.5以上的版本。我不能保证在<3.5的版本中它是否能正常工作。Tensorflow的兼容性问题一向令人头疼。注意在mac和linux上默认的python是python2。

  • numpy —— 一个知名的python数学计算库,在矩阵和数组运算方面非常强大
  • tensorflow —— 机器学习库,直接用pip安装的是cpu版本。如果你的pc有一个足够好的独立显卡,可以试试tensorflow-gpu。它可以使训练的速度更快。但tensorflow-gpu的配置方法比较复杂。我们的模型比较简单,即使用cpu训练也不会耗时太久。
  • tensorflowjs —— 用于导出并保存可以被浏览器使用的模型文件
  • flask —— 一个轻量级的python网络服务框架
  • flask-cors —— 用于支持flask跨域请求的一个库

这个demo内已经包含一个已经训练好的模型,所以你如果并不想自己再训练一次,可以不安装tensorflow和tensorflowjs。所有这些依赖都可以通过pip安装。

JavaScript:

你不需要特别安装任何东西,因为我们的库都是通过链接导入的。

  • tf.js —— 它就是本文要介绍的,尽管只会涉及它的极小的一点。
  • fabric.js —— 可选,用于比较方便地构造画布。
  • Chart.js —— 可选,只是用来画出下边的图表的。你也可以不要它,如果你对这种可视化的结果不感兴趣。
浏览器:

反正在chrome浏览器里是能跑起来的……

训练

项目文件里面已经包含了一个训练好的模型,位于{项目路径}/server/models/mnist文件夹内。

我们使用MNIST数据集来训练模型。MNIST是一个知名的手写数字识别的数据集。对很多机器学习的初学者而言,这很可能是他们接触到的第一个数据集。这个数据集中包含60000张训练图片以及10000张测试图片,每张图片都是一个28×28像素的手写数字图片。如下图所示:

mnist.png

MNIST用一个28×28的矩阵来代表这样的一张数字图片,矩阵内的每个元素表示对应点位置的灰度,在0~255之间。

下载数据:

事实上,你可以跳过下载数据这一步而直接开始训练,因为在训练函数中会自动下载数据,但鉴于国内糟糕的网络环境,我还是建议你先把数据手动下载下来。我会优先从本地读取数据。

下载地址:mnist.npz
下载完成后保存在路径{项目路径}/server/datasets/mnist.npz的位置。npz是numpy的一种数据压缩格式。文件大小大概11m。然后我们用load_data函数载入数据:

import numpy as np
from tensorflow.keras import layers, datasets

def load_data(path):
    try:
        with np.load(path) as f:
            x_train, y_train = f['x_train'], f['y_train']
            x_test, y_test = f['x_test'], f['y_test']
            x_train, x_test = x_train/255.0, x_test/255.0
            return (x_train, y_train), (x_test, y_test)
    except FileNotFoundError:
        return datasets.mnist.load_data()

其中,x_train是一个60000×28×28的3维向量,代表60000张图片;y_train是长度60000的向量,每一项代表对应图片的实际数字,是一个0~9的整数。x_test,y_test是测试集上的对应数据,测试集大小为10000。注意x_train, x_test = x_train/255.0, x_test/255.0这一步是把每个灰度数字转换为一个0~1之间的小数。

训练模型:

我们使用tensorflow.keras的接口来实现一个简单的卷积神经网络(Convolutional Neural Network, CNN)模型。它包含了一个卷积层,一个池化层,和两个全连接层。我不会这里解释全部概念——对新手来说,它们过于令人困惑而费解。而且你也不需要在这里理解它。如果你真的很想从直观上把握它的话,你可以试试这篇博客:An Intuitive Explanation of Convolutional Neural Networks。它有点长,但为此花一些时间依然是值得的。

在server/train.py文件下可以看到训练函数的代码:

from tensorflow.keras.models import Sequential
from tensorflow.keras import layers, datasets
import tensorflowjs as tfjs

def train_modle(data):
    (x_train, y_train), (x_test, y_test) = data
    model = Sequential([
        layers.Reshape((28, 28, 1), input_shape=(28, 28)),
        layers.Conv2D(16, (5, 5), padding='valid', input_shape=(28, 28, 1), activation='relu'),
        layers.MaxPooling2D(pool_size=2),
        layers.Dropout(0.2),
        layers.Flatten(),
        layers.Dense(128, activation='relu'),
        layers.Dense(10, activation='softmax')
    ])
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    model.fit(x_train, y_train, epochs=5, batch_size=64)
    model.evaluate(x_test, y_test)
    tfjs.converters.save_keras_model(model, model_path)

从上至下,这个训练函数做的:

  • 读取从npz中或自动下载的训练数据。
  • 构造模型的层序列,tensorflow.keras是以layer这种对象组织计算过程的。每一层输出是下一层的输入,最后的输出就是模型的输出。它每层依次是:
    • Reshape。注意我们的输入的数据对应的是图片的每个像素的灰度,是没有‘深度’的。而卷积层要求的输入必须是有‘深度’的。所以我们首先为数据额外地加一个‘深度’为1的第三个维度。
    • 卷积层。这一层的作用是提取一个图片的每一个点周围的‘局部特征’,并传递给下一层。我们需要16种特征,对每一种特征,我们用一个5×5大小的矩阵,‘扫描’图片,并据此计算出一个值。所以这一步,是把每个点都映射到一个长16的向量,来代表这个点的16种不同的局部特征。
    • 池化层。这一步是为了降低数据的大小。在所有的相邻的2×2的的范围内,我们只保留其中的最大值。
    • Dropout。在训练过程中,每次更新参数时,随机地把一部分输入节点忽略掉。这是一种防止过拟合的简单技巧。它只会应用于训练时,不会用在预测上。
    • Flatten。输入展平成一个一维数组。如果你的下一层是全连接层,那么这一步是必要的(除非你想把输出格式搞得一团糟)。
    • Dense。大小为128的全连接层。上一层的所有点都与这层的所有点相连。
    • Dense。大小为10的最末端的全连接层,它的输出就是模型的预测结果,对应一张图片是每个数字的概率。
  • 损失函数是用于估计模型预测结果和正确结果的偏差的函数。比如实际的数字为2,那么我们期待的结果应该是[ 0, 0, 1, 0, 0, 0, 0, 0, 0, 0 ],即除第3位为1外其他的都是0;而我们的预测结果可能是[0.2, 0.3, 0.1,...]。我们这里使用交叉熵算法来评估两种概率分布间的差别。训练的目的就是使得这样的损失函数的值尽量接近0。
  • 优化函数决定了在预测结果和正确结果的特定偏差下,应该如何更新参数。这里我们使用adam优化器。
  • fit。使用训练集来训练。我们每次在大小60000数据中取出64个作为一批,计算损失函数并优化参数。在整个数据集上,重复5次。
  • evaluate。使用测试集来评估训练结果。只计算损失函数,不做参数优化。
  • 最后一步,是把模型的训练结果保存成文件,在预测时可以调用。

撇开数学上的概念理解不谈,一般初学者在训练过程中最容易让人弄错的地方是数据的格式(shape)。

运行文件,开始训练

python server/train.py

如果你的环境配置正确,你应该会看到这样的输出:

60000/60000 [==============================] - 11s 185us/sample - loss: 0.1896 - acc: 0.9453
Epoch 2/5
60000/60000 [==============================] - 14s 225us/sample - loss: 0.0678 - acc: 0.9791
Epoch 3/5
60000/60000 [==============================] - 13s 221us/sample - loss: 0.0504 - acc: 0.9840
Epoch 4/5
60000/60000 [==============================] - 14s 233us/sample - loss: 0.0377 - acc: 0.9881
Epoch 5/5
60000/60000 [==============================] - 14s 231us/sample - loss: 0.0301 - acc: 0.9900
10000/10000 [==============================] - 1s 93us/sample - loss: 0.0360 - acc: 0.9879

在我的i5 cpu电脑上,整个训练过程大约耗时不到1分钟。
这个输出的结果显示了每一个epoch的耗时、损失函数的值和准确率。最后一行是在测试集上的结果。可以看到,我们的训练结果在测试集上有 98.79% 的准确率。同时,在{项目路径}/server/models/mnist内的文件也会被覆盖更新。你可以调整模型的结构和条件,多试几次,来评估不同条件下的训练结果。

{项目路径}/server/models/mnist下有两个文件,一个很小的model.json文件和另一个大小约1m以上的.bin文件。model.json文件可以直接打开,里面包括了模型的一些总体信息,如模型的结构和参数文件的位置,也就是.bin文件,这个文件记录了这个模型训练出来的所有参数。另外,如果你已经改动过模型的结构或者其他条件重新训练,那么这样的参数文件可能不止一个。

服务

我们已经训练好了模型,但这个模型文件是不能直接被浏览器载入使用的,因为现代浏览器一般都会阻止js直接读取本地文件内容。并且在设计上,这个模型文件也应该是保存在服务端而不是客户端。
我们需要做的,是启动一个服务,并使得这个模型成为这个服务的静态资源,这样js就可以通过请求拉取文件内容。
在server目录下的main.py文件:

from flask import Flask
from flask_cors import CORS

app = Flask(__name__,
            static_url_path='/models', 
            static_folder='models')

cors = CORS(app) 

@app.route("/")
def hello():
    return "Hello World!"

if __name__ == '__main__':
    app.run(debug=True)

这是一个非常简单的flask应用代码。在这个应用中,我们把url路径/models映射到了文件目录models(相对于本文件),外界通过{host}/models就能访问到models内的文件。
在项目的根目录下,用命令行启动这个文件:

python server/main.py

如果一切正常的话,你应该会看到这样的输出

* Serving Flask app "main" (lazy loading)
* Environment: production
WARNING: This is a development server. Do not use it in a production deployment
Use a production WSGI server instead.
* Debug mode: on
* Restarting with stat
* Debugger is active!
* Debugger PIN: 267-971-636
* Running on http://127.0.0.1:5000/ (Press CTRL+C to quit)

现在,你可以打开 http://127.0.0.1:5000/ 或者 http://localhost:5000/,如果你在屏幕上看到了“Hello World!”,就说明服务已经启动成功了。ctrl+c可以退出服务。
此时如果你打开http://localhost:5000/models/mnist/model.json就可以看到我们之前训练出来的模型的json文件。
另外,注意在代码中,我们还加了一句cors = CORS(app),这是为了让这个服务接受跨域请求。本文在这里不会展开讨论这个问题,简单地说:如果在js脚本中试图请求拉取的后端资源的协议或域名与js本身的不一致,那么浏览器会阻止这个请求——这是一种安全保护策略,除非你加了这行代码让后端资源接受跨域。

预测

我们在html头引入这几个库:

    <head>
        <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@1.0.0/dist/tf.min.js"></script>
        <script src='http://cdnjs.cloudflare.com/ajax/libs/fabric.js/1.4.0/fabric.min.js'></script>
        <script src="https://cdnjs.cloudflare.com/ajax/libs/Chart.js/2.4.0/Chart.min.js"></script>
        <script src='./model.js'></script>
    </head>

最后一个model.js是本地的js文件,我们会把模型的导入和数据预测函数都封装在这里。

在model.js文件里,导入模型:

const MODEL_URL = 'http://localhost:5000/models/mnist/model.json' 

var loadModel = (async function() {
    window.model = await tf.loadLayersModel(MODEL_URL);
    console.log('load model')
    return model;
})
loadModel();

MODEL_URL 就是模型的model.json文件的url地址。用tf.loadLayersModel函数来载入模型并绑定在window上。当你在浏览器控制台里看到 'load model',模型就载入成功了。注意tf.loadLayersModel是异步函数,它返回的是一个Promise对象,你需要用await或者.then()的回调式方法来获取载入的模型对象。

在html里,添加一个id="canvas"的canvas和两个按钮,一个用于识别,另一个用于清空canvas。

        <div id='container'>
            <canvas width="140" height="140" id="canvas" class="canvas"></canvas>
            <div class='button-container'>
                <button onclick="recognize()">recognize</button>
                <button onclick="clear_canvas()">clear</button>   
            </div>
        </div>

在html内<script>内加上:

        var fabric_canvas = new fabric.Canvas('canvas', {backgroundColor: "#000000"});
        fabric_canvas.renderTop();
        fabric_canvas.isDrawingMode = true;
        fabric_canvas.freeDrawingBrush.width = 12;
        fabric_canvas.freeDrawingBrush.color = "#ffffff";

        var recognize = async function() {
            var results = await predict('canvas');
            console.log(results);
        }

        var clear_canvas = function() {
            fabric_canvas.clear();
        }

我们使用fabric.js来构造可以任意涂抹的画图。这并不是必要的,只是可以少写一点代码。

在识别图片recognize函数内,我们调用了一个predict函数,并传入了canvas的id。我们希望这个函数返回的结果就是预测结果。

const width = 28;
const height = 28;

var predict = async function(id) {
    var model = window.model;
    var canvas = document.getElementById(id);
    var example = this.load_img(canvas);
    var prediction = await model.predict(example).data();
    var results = Array.from(prediction);
    return results
}

var load_img = function(img) {
    var tensor = tf.browser.fromPixels(img)
        .resizeNearestNeighbor([width, height])
        .mean(2)
        .expandDims()
        .toFloat()
        .div(255.0)
    return tensor;
};

predict函数的逻辑也相当直接了当:

  • 获取canvas对象
  • 调用load_img函数,从图片得到tensor张量对象
  • 调用已经载入的模型,预测结果。这里值得注意的是模型的预测函数model.predict同样是一个异步函数,需toFloat要在它的回调的.data方法中取出预测结果。这个结果默认是Float32Array类型,可以转换为Array。

与之前训练模型类似,最麻烦的地方是数据格式shape的处理,我们在load_img里有这么几步:

  • 用tf.browser.fromPixels方法读取canvas的像素信息。这个方法同样可以读取图片的信息。返回的结果是一个张量tensor,shape是140×140×3。最后一维是这张图片的每个像素在3个颜色通道上的值,每个值是一个0~255之间的整数。
  • 把140×140的数据resize成一个28×28的数据。因为我们训练的模型只接受28×28大小。此时的大小为28×28×3。
  • 计算灰度。.mean方法是求平均值的方法,用它我们把第3维的颜色转换为灰度。考虑到我们的图片是黑白的,它在3个颜色上应该是一样的,所以在这里我们也可以用.min.max(最小值、最大值)来计算灰度。此时的大小是28×28。
  • 我们的模型接受的必须是多个图片,所以我们用.expandDims加上一维。此时大小为1×28×28。这里可以用.reshape([1, 28, 28])来达到同样的效果。
  • toFloat,把tensor的元素转换为Float类型。
  • 记住我们的模型在处理mnist输入之前的时候曾做过一个除以255的操作,把灰度转为了0~1之间的小数,这里我们也要做一个同样的处理.div(255.0)

现在我们在canvas上写数据,再点击recognize按钮,就能在浏览器的控制台里看到预测的结果:

Array(10) [ 2.229090443993517e-15, 1.264737121454973e-12, 6.231850036009234e-10, 0.9999980926513672, 7.358067470207216e-14, 7.870837634982308e-7, 3.1836545118929527e-13, 6.341550395916329e-9, 8.096231454146618e-7, 1.0121870008816813e-10 ]

这个长度为10的数组表示模型预测canvas上的图片是0~9之间每个数字的概率。在我的项目里我还加了一个直方图表来表示这个数据,本文略过此处。

其他

一些其他的值得注意的地方:

  • 在这个demo中,训练使用的是python,tf.js只用于预测。而tf.js本身也可以用于训练数据。只是在浏览器里做训练意义不大。
  • 我们的网页是通过直接打开index.html来打开的。事实上,我们也可以把web文件夹下的index.html和model.js放在flask服务的静态资源路径里,这样就能通过url来访问网页了,并且这样flask也不用启用CORS,因为没有跨域。

所有的代码都在这里:digits-recognition-tfjs。作者十分感谢这篇博客:Recognizing Digits using TensorFlow.js in Google Chrome,它对本文启发很大。

如果你觉得这篇文章有帮助的话,记得赞赏。

written by CC, with love.

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 158,736评论 4 362
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 67,167评论 1 291
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 108,442评论 0 243
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 43,902评论 0 204
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 52,302评论 3 287
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 40,573评论 1 216
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 31,847评论 2 312
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 30,562评论 0 197
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 34,260评论 1 241
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 30,531评论 2 245
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 32,021评论 1 258
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 28,367评论 2 253
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 33,016评论 3 235
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 26,068评论 0 8
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 26,827评论 0 194
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 35,610评论 2 274
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 35,514评论 2 269

推荐阅读更多精彩内容