机器学习Tensorflow笔记3:Python训练MNIST模型,在Android上实现评估

通常而言我们会通过Python编写代码训练Tensorflow,但是我们训练的数据需要实际应用起来,本文会介绍如何通过Python训练Tensorflow,训练的结果在Android上应用,当前也可以通过传输数据给服务端去识别,然后返回数据,但是这种方式实时性较差,需要上传识别数据,然后等待返回数据,在某些场景下也是适用,可以查看下面的Java中调用文章。

实战

实战的内容是基于MNIST实验,在Android平台实现识别功能。

本文是基于MNIST实验,如果还没有做过MNIST实验,那么可以先看我之前2篇文章
《机器学习Tensorflow笔记1:Hello World到MNIST实验》
《机器学习Tensorflow笔记2:超详细剖析MNIST实验》

1. Python保存训练模型

在MNIST实验中,我们是训练完成模型后马上就调用测试代码,如果我们要应用起来,就不可能在移动端去训练,我们应该把训练好的模型放在手机里面,或者通过URL下载到手机里面,所以我们需要保存我们的训练的模型。

#!/usr/bin/python
# -*- coding: UTF-8 -*-
import gzip
import sys
import struct
import numpy

from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile

train_images_file = "MNIST_data/train-images-idx3-ubyte.gz"
train_labels_file = "MNIST_data/train-labels-idx1-ubyte.gz"
t10k_images_file = "MNIST_data/t10k-images-idx3-ubyte.gz"
t10k_labels_file = "MNIST_data/t10k-labels-idx1-ubyte.gz"


def read32(bytestream):
    # 由于网络数据的编码是大端,所以需要加上>
    dt = numpy.dtype(numpy.int32).newbyteorder('>')
    data = bytestream.read(4)
    return numpy.frombuffer(data, dt)[0]


def read_labels(filename):
    with gzip.open(filename) as bytestream:
        magic = read32(bytestream)
        numberOfLabels = read32(bytestream)
        print(magic)
        print(numberOfLabels)
        labels = numpy.frombuffer(bytestream.read(numberOfLabels), numpy.uint8)
        data = numpy.zeros((numberOfLabels, 10))
        for i in xrange(len(labels)):
            data[i][labels[i]] = 1
        bytestream.close()
    return data


def read_images(filename):
    # 把文件解压成字节流
    with gzip.open(filename) as bytestream:
        magic = read32(bytestream)
        numberOfImages = read32(bytestream)
        rows = read32(bytestream)
        columns = read32(bytestream)
        images = numpy.frombuffer(bytestream.read(numberOfImages * rows * columns), numpy.uint8)
        images.shape = (numberOfImages, rows * columns)
        images = images.astype(numpy.float32)
        images = numpy.multiply(images, 1.0 / 255.0)
        bytestream.close()
        print(magic)
        print(numberOfImages)
        print(rows)
        print(columns)
    return images


# 解析labels的内容,train_labels包含了60000个数字标签,返回60000个数字标签的数组
train_labels = read_labels(train_labels_file)
# print(labels)
train_images = read_images(train_images_file)

test_labels = read_labels(t10k_labels_file)
# print(labels)
test_images = read_images(t10k_images_file)

import tensorflow as tf

x = tf.placeholder("float", [None, 784.],name='input/x_input')
W = tf.Variable(tf.zeros([784., 10.]))
b = tf.Variable(tf.zeros([10.]))
y = tf.nn.softmax(tf.matmul(x, W) + b)
y_ = tf.placeholder("float",name='input/y_input')
cross_entropy = -tf.reduce_sum(y_ * tf.log(y))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
for i in range(1200):
    batch_xs = train_images[50 * i:50 * i + 50]
    batch_ys = train_labels[50 * i:50 * i + 50]
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})


correct_prediction = tf.equal(tf.argmax(y, 1, output_type='int32', name='output'),
                              tf.argmax(y_, 1, output_type='int32'))

# correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
print sess.run(accuracy, feed_dict={x: test_images, y_: test_labels})

# 保存训练好的模型
# 形参output_node_names用于指定输出的节点名称,output_node_names=['output']对应pre_num=tf.argmax(y,1,name="output"),
output_graph_def = graph_util.convert_variables_to_constants(sess, sess.graph_def, output_node_names=['output'])
with tf.gfile.FastGFile('model/mnist.pb', mode='wb') as f:  # ’wb’中w代表写文件,b代表将数据以二进制方式写入文件。
    f.write(output_graph_def.SerializeToString())
sess.close()

通过简单的修改代码,就可以轻松实现保存训练模型到本地。

测试导出的模型是否可用
#!/usr/bin/python
# -*- coding: UTF-8 -*-
import tensorflow as tf
import numpy as np
from PIL import Image

#模型路径
model_path = 'model/mnist.pb'
#测试图片
testImage = Image.open("data/test_image.png")

with tf.Graph().as_default():
    output_graph_def = tf.GraphDef()
    with open(model_path, "rb") as f:
        output_graph_def.ParseFromString(f.read())
        tf.import_graph_def(output_graph_def, name="")

    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        # x_test = x_test.reshape(1, 28 * 28)
        input_x = sess.graph.get_tensor_by_name("input/x_input:0")
        output = sess.graph.get_tensor_by_name("output:0")

        #对图片进行测试
        testImage=testImage.convert('L')
        testImage = testImage.resize((28, 28))
        test_input=np.array(testImage)
        test_input = test_input.reshape(1, 28 * 28)
        pre_num = sess.run(output, feed_dict={input_x: test_input})#利用训练好的模型预测结果
        print('模型预测结果为:',pre_num)

2. 配置项目

  1. 在app目录对于的build.gradle添加Gradle依赖,由于so文件很大,所以建议只支持arm,引入Tensorflow后,apk仅仅只增加了4.9MB,如果人工智能当做重要的业务,这个成本是值得的,后续我也会编写Tensorflow Lite的文章,体积更小,更加适合移动设备。
android {
      //...
    buildTypes {
       debug {
            minifyEnabled false
            debuggable = false  
            proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro'
            ndk {
                abiFilters "armeabi-v7a","x86"
            }
        }
        release {
            minifyEnabled false
            debuggable = false
            proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro'
            ndk {
                abiFilters "armeabi-v7a"
            }
        }
    }
}
dependencies {
    implementation 'org.tensorflow:tensorflow-android:1.8.0'
}

  1. 把上面保存好的训练模型放到Android项目中的assets文件夹中,同时把需要测试的图片放到drawable文件夹下。
├── main
│   ├── AndroidManifest.xml
│   ├── assets
│   │   └── mnist.pb
│   └── res
│       ├── drawable
│       │   └── test_image.png
test_image.png
image.png

image.png
测试模型
class MainActivity : AppCompatActivity() {

    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        setContentView(R.layout.activity_main)

        val bitmap = BitmapFactory.decodeResource(resources, R.drawable.test_image)
        val tfi = TensorFlowInferenceInterface(assets, "mnist.pb")
        val inputData = bitmapToFloatArray(bitmap, 28f, 28f)
        tfi.feed("input/x_input", inputData, 1, 784)
        val outputNames = arrayOf("output")
        tfi.run(outputNames)
        // 用于存储模型的输出数据
        val outputs = IntArray(1)
        tfi.fetch(outputNames[0], outputs)

        imageView.setImageBitmap(bitmap)
        textView.text = "结果为:" + outputs[0]
    }

    /**
     * 将bitmap转为(按行优先)一个float数组,并且每个像素点都归一化到0~1之间。
     * @param bitmap 输入被测试的bitmap图片
     * @param rx 将图片缩放到指定的大小(列)->28
     * @param ry 将图片缩放到指定的大小(行)->28
     * @return   返回归一化后的一维float数组 ->28*28
     */
    private fun bitmapToFloatArray(bitmap: Bitmap, rx: Float, ry: Float): FloatArray {
        var height = bitmap.height
        var width = bitmap.width
        // 计算缩放比例
        val scaleWidth = rx / width
        val scaleHeight = ry / height
        val matrix = Matrix()
        matrix.postScale(scaleWidth, scaleHeight)
        val bitmap = Bitmap.createBitmap(bitmap, 0, 0, width, height, matrix, true)
        height = bitmap.height
        width = bitmap.width
        val result = FloatArray(height * width)
        var k = 0
        for (row in 0 until height) {
            for (col in 0 until width) {
                val argb = bitmap.getPixel(col, row)
                val r = Color.red(argb)
                val g = Color.green(argb)
                val b = Color.blue(argb)
                //由于是灰度图,所以r,g,b分量是相等的。
                assert(r == g && g == b)
                result[k++] = r / 255.0f
            }
        }
        return result
    }
}

布局文件

<?xml version="1.0" encoding="utf-8"?>
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    android:padding="10dp"
    android:orientation="vertical">

    <ImageView
        android:id="@+id/imageView"
        android:layout_width="100dp"
        android:layout_height="100dp"
        android:layout_gravity="center"
        android:scaleType="fitXY" />

    <TextView
        android:id="@+id/textView"
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:layout_marginTop="20dp"
        android:gravity="center"
        android:text="结果为:" />
</LinearLayout>
结果
image.png
源码

https://github.com/taoweiji/TensorflowAndroidDemo

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 162,408评论 4 371
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 68,690评论 2 307
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 112,036评论 0 255
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 44,726评论 0 221
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 53,123评论 3 296
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 41,037评论 1 225
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 32,178评论 2 318
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 30,964评论 0 213
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 34,703评论 1 250
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 30,863评论 2 254
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 32,333评论 1 265
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 28,658评论 3 263
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 33,374评论 3 244
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 26,195评论 0 8
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 26,988评论 0 201
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 36,167评论 2 285
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 35,970评论 2 279

推荐阅读更多精彩内容