利用Tensorflow训练自己的分类器

转自:

https://www.cnblogs.com/andre-ma/p/8458172.html

【写在前面】

用Tensorflow(TF)已实现好的卷积神经网络(CNN)模型来训练自己的数据集,验证目前较成熟模型在不同数据集上的准确度,如Inception_V3, VGG16,Inception_resnet_v2等模型。本文验证Inception_resnet_v2基于菜场实拍数据的准确性,测试数据为芹菜、鸡毛菜、青菜,各类别样本约600张,多个菜场拍摄,不同数据源。


补充:自己当初的计划是用别人预训练好的模型来再训练自己的数据集已使可以完成新的分类任务,但必须要修改代码改网络结构,并使用迁移学习(Fine-tune)

本文记录了其间的工作过程 ,  相信也会有一些帮助的  : )


测试环境:Centos7.3-64位  python3.5.4(Anaconda) 


 目录

一.准备

1.安装python

2.安装tensorflow 

3.下载TF-slim图像库

4.准备数据

5.下载模型

二.训练

1.读入数据

2.构建模型

3.开始训练

4.执行脚本,训练自己的数据

5.可视化log

【问题】 tensorboard版本已更新,找不到对应包

三.验证

四.测试



一.准备

1.安装python

推荐Anaconda,可创建虚拟环境,用conda命令易实现虚拟环境管理、包管理,安装包时会查出所有依赖包并一共一键安装, 链接:https://www.anaconda.com/download/

2.安装tensorflow 

进入当下Anaconda的运行环境,我安装的是python2.7版,并创建3.5虚拟环境

conda create -n py35 python=3.5    【py35是虚拟环境的名称; 输入y 安装】

source activate py35  【激活py35环境】

conda install tensorflow  【安装tensorflow-cpu版,有GPU可安装cpu版】

3.下载TF-slim代码库


cd  $WORKSPACE  【目录跳转到自己的工作目录下】

git clone https://github.com/tensorflow/models/

4.准备数据

对所有训练样本按不同样本类别存在不同文件夹下

zsy_train|---jimaocai

  |---0.jpg

  |---  ...|---qc|---qingcai

下面的代码是为了生成list.txt , 把不同文件夹下的图片和 数字label对应起来

1import os 2class_names_to_ids = {'jimaocai': 0,'qc': 1,'qingcai': 2} 3data_dir ='flower_photos/' 4output_path ='list.txt' 5fd = open(output_path,'w') 6forclass_namein class_names_to_ids.keys(): 7images_list = os.listdir(data_dir + class_name) 8forimage_namein images_list: 9fd.write('{}/{} {}\n'.format(class_name, image_name, class_names_to_ids[class_name]))10fd.close()

为了方便后期查看label标签,也可定义labels.txt

jimaocai

qc

qingcai

随机生成训练集和验证集(在总量中随机选取350个样本作为验证集)

1import random 2_NUM_VALIDATION = 350 3_RANDOM_SEED = 0 4list_path ='list.txt' 5train_list_path ='list_train.txt' 6val_list_path ='list_val.txt' 7fd = open(list_path) 8lines = fd.readlines() 9fd.close()10random.seed(_RANDOM_SEED)11random.shuffle(lines)12fd = open(train_list_path,'w')13forlinein lines[_NUM_VALIDATION:]:14    fd.write(line)15fd.close()16fd = open(val_list_path,'w')17forlinein lines[:_NUM_VALIDATION]:18    fd.write(line)19fd.close()

生成TFRecord数据

import sys# sys.path.insert(0, '../models/slim/')  models-master researchsys.path.insert(0,'./models/research/slim/')#把后面的路径插入到系统路径中 idx=0fromdatasetsimport dataset_utilsimport mathimport osimport tensorflow as tf#  根据list路径  把数据转化为TFRecord# def convert_dataset(list_path, data_dir, output_dir, _NUM_SHARDS=5):  defconvert_dataset(list_path, data_dir, output_dir, _NUM_SHARDS=3):     

    fd = open(list_path)

    lines = [line.split()forlinein fd]

    fd.close()

    num_per_shard = int(math.ceil(len(lines) / float(_NUM_SHARDS)))

    with tf.Graph().as_default():

        decode_jpeg_data = tf.placeholder(dtype=tf.string)

        decode_jpeg = tf.image.decode_jpeg(decode_jpeg_data, channels=3)

        with tf.Session('') as sess:

            forshard_idin range(_NUM_SHARDS):

                output_path = os.path.join(output_dir,#                    'data_{:05}-of-{:05}.tfrecord'.format(shard_id, _NUM_SHARDS))'data_{:03}-of-{:03}.tfrecord'.format(shard_id, _NUM_SHARDS))

                tfrecord_writer = tf.python_io.TFRecordWriter(output_path)

                start_ndx = shard_id * num_per_shard

                end_ndx = min((shard_id + 1) * num_per_shard, len(lines))

                foriin range(start_ndx, end_ndx):

                    sys.stdout.write('\r>> Converting image {}/{} shard {}'.format(

                        i + 1, len(lines), shard_id))

                    sys.stdout.flush()

                    image_data = tf.gfile.FastGFile(os.path.join(data_dir, lines[i][0]),'rb').read()

                    image = sess.run(decode_jpeg, feed_dict={decode_jpeg_data: image_data})

                    height, width = image.shape[0], image.shape[1]

                    example = dataset_utils.image_to_tfexample(

                        image_data, b'jpg', height, width, int(lines[i][1]))

                    tfrecord_writer.write(example.SerializeToString())

                tfrecord_writer.close()

    sys.stdout.write('\n')

    sys.stdout.flush()


os.system('mkdir -p train')

convert_dataset('list_train.txt','zsy_train','train/')

os.system('mkdir -p val')

convert_dataset('list_val.txt','zsy_train','val/')

得到的文件夹结构如下

WORKSPACE

├── zsy_train

├── labels.txt

├── list_train.txt

├── list.txt

├── list_val.txt

├── train

│  ├── data_000-of-003.tfrecord

│  ├── ...

│  └── data_002-of-003.tfrecord

└── val

    ├── data_000-of-003.tfrecord

    ├── ...

    └── data_002-of-003.tfrecord

5.下载模型

官方提供了预训练,这里以Inception-ResNet-v2以例

cd $WORKSPACE/checkpointswgethttp://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gztarzxf inception_resnet_v2_2016_08_30.tar.gz


二.训练

1.读入数据

读入自己的数据,需要把下面代码写入models/slim/datasets/dataset_classification.py

import osimport tensorflow as tf

slim = tf.contrib.slimdefget_dataset(dataset_dir, num_samples, num_classes, labels_to_names_path=None, file_pattern='*.tfrecord'):

    file_pattern = os.path.join(dataset_dir, file_pattern)

    keys_to_features = {

        'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),

        'image/format': tf.FixedLenFeature((), tf.string, default_value='jpg'),

        'image/class/label': tf.FixedLenFeature(

            [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),

    }

    items_to_handlers = {

        'image': slim.tfexample_decoder.Image(),

        'label': slim.tfexample_decoder.Tensor('image/class/label'),

    }

    decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)

    items_to_descriptions = {

        'image':'A color image of varying size.',

        'label':'A single integer between 0 and '+ str(num_classes - 1),

    }

    labels_to_names = None

    iflabels_to_names_pathisnot None:

        fd = open(labels_to_names_path)

        labels_to_names = {i : line.strip()fori, linein enumerate(fd)}

        fd.close()

    return slim.dataset.Dataset(

            data_sources=file_pattern,

            reader=tf.TFRecordReader,

            decoder=decoder,

            num_samples=num_samples,

            items_to_descriptions=items_to_descriptions,

            num_classes=num_classes,

            labels_to_names=labels_to_names)

2.构建模型

构建模型取决于个人欲构建什么样的模型,官方都有对应模型的下载链接,只需把对应下载(下载链接:https://github.com/tensorflow/models/tree/master/research/slim)好的模型解压放入到checkpoints中即可

3.开始训练

由于是用已有模型训练自己的数据集,故需对原工程代码做适当调整。

fromdatasetsimportdataset_factory

改为:

fromdatasetsimportdataset_classification

dataset = dataset_factory.get_dataset(

    FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

改为:

dataset = dataset_classification.get_dataset(

    FLAGS.dataset_dir, FLAGS.num_samples, FLAGS.num_classes, FLAGS.labels_to_names_path)

tf.app.flags.DEFINE_string(

    'dataset_dir', None,'The directory where the dataset files are stored.')

后加入:

tf.app.flags.DEFINE_integer(

    'num_samples', 1781,'Number of samples.')

tf.app.flags.DEFINE_integer(

    'num_classes', 3,'Number of classes.')

tf.app.flags.DEFINE_string(

    'labels_to_names_path', None,'Label names file path.')

 4.执行脚本,训练自己的数据

cd $WORKSPACE/models/slim    #跳转到工作环境目录

python train_image_classifier.py \    #运行脚本,后面跟的系统参数

    --train_dir=/root/workspace_mrt/model_lab/train_logs \  #train_log目录,当模型训练时,可用tensorboard命令指定该目录,动态监测

    --dataset_dir=../../../train \    #训练数据集  里面是转换好的TFRecord格式

    --num_samples=1781 \          #训练样本数,即值train_set中的总样本数,不包括valid中随机抽取350个样本

    --num_classes=3 \            #样本类别数

    --labels_to_names_path=../../../labels.txt \ 

    --model_name=inception_resnet_v2 \

    --checkpoint_path=../../../checkpoints/inception_resnet_v2_2016_08_30.ckpt \      #指定模型位置

    --checkpoint_exclude_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits \

    --trainable_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits \

    --clone_on_cpu=True        #cpu训练必须加上该参数

#fine-tune要把 --checkpoint_path,--checkpoint_exclude_scopes,--trainable_scopes 加上

5.可视化log

为了可视化训练时的loss或其他指标,可用tensorboard,如下命令

tensorboard --logdir=${TRAIN_DIR}

在本教程中,对应执行下面命令

tensorboard --logdir=/root/workspace_mrt/model_lab/train_logs

【问题】 tensorboard版本已更新,找不到对应包

当执行

tensorboard --logdir=/root/workspace_mrt/model_lab/train_logs

时,得到如下错误

ImportError: No module named'tensorflow.tensorboard.tensorboard'

究其原因,是因为在tensorflow更新时,包的位置和所属关系改变了。执行以下代码,可解决该问题。

cd /root/anaconda2/envs/py35/bin    #跳转到对应python环境的bin目录下,修改tensorboard执行脚本代码,使之适应当前版本

vim tensorboard

import tensorflow.tensorboard.tensorboard

修改为:

import tensorboard.main

sys.exit(tensorflow.tensorboard.tensorboard.main())

修改为:sys.exit(tensorboard.main.main())

wq保存,退出,重新执行

tensorboard --logdir=/root/workspace_mrt/model_lab/train_logs

命令,无报错。根据日志提示,进入ip:6006进入tensorboard界面。

三.验证

使用自己的数据集,需修改models/slim/eval_image_classifier.py

from datasets import dataset_factory

改为:

from datasets import dataset_classification

dataset = dataset_factory.get_dataset( FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

改为:

dataset = dataset_classification.get_dataset(

    FLAGS.dataset_dir, FLAGS.num_samples, FLAGS.num_classes, FLAGS.labels_to_names_path)

tf.app.flags.DEFINE_string(

    'dataset_dir', None,'The directory where the dataset files are stored.')

后加入

tf.app.flags.DEFINE_integer(

    'num_samples', 350,'Number of samples.')

tf.app.flags.DEFINE_integer(

    'num_classes', 3,'Number of classes.')

tf.app.flags.DEFINE_string(

    'labels_to_names_path', None,'Label names file path.')

验证时执行以下命令即可:

python eval_image_classifier.py \

    --checkpoint_path=../../../checkpoints/inception_resnet_v2_2016_08_30.ckpt \

    --eval_dir=/root/workspace_mrt/model_lab/eval_logs \

    --dataset_dir=../../../val \

    --num_samples=350 \

    --num_classes=3 \

    --model_name=inception_resnet_v2

可以一边训练一边验证,注意使用其它的GPU或合理分配显存。

同样也可以可视化log,如果已经在可视化训练的log则建议使用其它端口,如:

tensorboard --logdir ../../../eval_logs/ --port 6007


四.测试

参考models/slim/eval_image_classifier.py,可编写批量读取图片用模型进行推导的脚本models/slim/test_image_classifier.py

from__future__import absolute_importfrom__future__import divisionfrom__future__import print_functionimport osimport jsonimport mathimport timeimport numpy as npimport tensorflow as tffromnetsimport nets_factoryfrompreprocessingimport preprocessing_factory

slim = tf.contrib.slim

tf.app.flags.DEFINE_string(

    'master','','The address of the TensorFlow master to use.')

tf.app.flags.DEFINE_string(

    'checkpoint_path', None,

    'The directory where the model was written to or an absolute path to a ''checkpoint file.')

tf.app.flags.DEFINE_string(

    'test_list','','Test image list.')

tf.app.flags.DEFINE_string(

    'test_dir','.','Test image directory.')

tf.app.flags.DEFINE_integer(

    'batch_size', 16,'Batch size.')

tf.app.flags.DEFINE_integer(

    'num_classes', 3,'Number of classes.')

tf.app.flags.DEFINE_integer(

    'labels_offset', 0,

    'An offset for the labels in the dataset. This flag is primarily used to ''evaluate the VGG and ResNet architectures which do not use a background ''class for the ImageNet dataset.')

tf.app.flags.DEFINE_string(

    'model_name','inception_resnet_v2','The name of the architecture to evaluate.')

tf.app.flags.DEFINE_string(

    'preprocessing_name', None,'The name of the preprocessing to use. If left ''as `None`, then the model_name flag is used.')

tf.app.flags.DEFINE_integer(

    'test_image_size', None,'Eval image size')

FLAGS = tf.app.flags.FLAGSdef main(_):

    ifnot FLAGS.test_list:

        raiseValueError('You must supply the test list with --test_list')

    tf.logging.set_verbosity(tf.logging.INFO)

    with tf.Graph().as_default():

        tf_global_step = slim.get_or_create_global_step()

        ##################### Select the model #####################network_fn = nets_factory.get_network_fn(

            FLAGS.model_name,

            num_classes=(FLAGS.num_classes - FLAGS.labels_offset),

            is_training=False)

        ###################################### Select the preprocessing function ######################################preprocessing_name = FLAGS.preprocessing_nameor FLAGS.model_name

        image_preprocessing_fn = preprocessing_factory.get_preprocessing(

            preprocessing_name,

            is_training=False)

        test_image_size = FLAGS.test_image_sizeor network_fn.default_image_size

        if tf.gfile.IsDirectory(FLAGS.checkpoint_path):

            checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)

        else:

            checkpoint_path = FLAGS.checkpoint_path

        batch_size = FLAGS.batch_size

        tensor_input = tf.placeholder(tf.float32, [None, test_image_size, test_image_size, 3])

        logits, _ = network_fn(tensor_input)

        logits = tf.nn.top_k(logits, 5)

        config = tf.ConfigProto()

        config.gpu_options.allow_growth = True

        test_ids = [line.strip()forlinein open(FLAGS.test_list)]

        tot = len(test_ids)

        results = list()

        with tf.Session(config=config) as sess:

            sess.run(tf.global_variables_initializer())

            saver = tf.train.Saver()

            saver.restore(sess, checkpoint_path)

            time_start = time.time()

            foridxin range(0, tot, batch_size):

                images = list()

                idx_end = min(tot, idx + batch_size)

                print(idx)

                foriin range(idx, idx_end):

                    image_id = test_ids[i]

                    test_path = os.path.join(FLAGS.test_dir, image_id)

                    image = open(test_path,'rb').read()

                    image = tf.image.decode_jpeg(image, channels=3)

                    processed_image = image_preprocessing_fn(image, test_image_size, test_image_size)

                    processed_image = sess.run(processed_image)

                    images.append(processed_image)

                images = np.array(images)

                predictions = sess.run(logits, feed_dict = {tensor_input : images}).indices

                foriin range(idx, idx_end):

                    print('{} {}'.format(image_id, predictions[i - idx].tolist())

            time_total = time.time() - time_start

            print('total time: {}, total images: {}, average time: {}'.format(

                time_total, len(test_ids), time_total / len(test_ids)))if__name__=='__main__':

    tf.app.run()

测试时执行以下命令即可:

CUDA_VISIBLE_DEVICES="0" python test_image_classifier.py \

    --checkpoint_path=../../../train_logs/ \

    --test_list=../../../list_val.txt \

    --test_dir=../../../val \

    --batch_size=16 \

    --num_classes=3 \

    --model_name=inception_resnet_v2


【参考】

https://lijiancheng0614.github.io/2017/06/29/2017_06_29_TensorFlow-Slim-image-classification/#%E9%AA%8C%E8%AF%81

https://github.com/tensorflow/models/tree/master/research/slim

【总结】

1.使用Tensorflow预训练模型(TF-slim)可以快速地测试已有成熟模型在不同数据集上的效果,且有利于数据代码架构和核心代码

2.若要自己实现已有成熟的网络模型,包括数据读取、训练、验证、测试,调优,则可用Tensorflow底层API或基于TF的高级API(TensorLayer  TFLearn  TF-slim)实现。从目前了解情况看,TensorLayer更好用,在DeepLearning中迁移学习更是如此

3.TensorFlow是一个非常庞大的架构,新手学习成本太高且难以摸清方向,故此借助下载已有models,调整参数和少量代码有助于进一步认识TensorFlow的实现原理,若想要更深入研究,手写CNN,看源码等都是不可避免的

4.当多看TensorFlow源码,迁移学习的一大利器!

最后附上自己的命令:

记得修改test_image_classifier.py修改topK


©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 159,835评论 4 364
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 67,598评论 1 295
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 109,569评论 0 244
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 44,159评论 0 213
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 52,533评论 3 287
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 40,710评论 1 222
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 31,923评论 2 313
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 30,674评论 0 203
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 34,421评论 1 246
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 30,622评论 2 245
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 32,115评论 1 260
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 28,428评论 2 254
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 33,114评论 3 238
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 26,097评论 0 8
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 26,875评论 0 197
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 35,753评论 2 276
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 35,649评论 2 271

推荐阅读更多精彩内容