[MXnet] 如何将数据集加载到MXnet中

MXnet的学习笔记,这次主要是使用MXnet提供的example模型进行训练时如何加载数据集的介绍。步骤基本上按照MXNet Python Data Loading API
有关MXnet在OSX下的编译安装,可以看这里Mac下编译安装MXNet
有关MXnet提供的example的综述介绍<-在这里。

Sample iterator for data loading


在浏览完MXnet提供的example后想要在自己的机器上跑一下简单的数据集看看结果。因为现在只是装在自己的MBA上,没有装CUDA和OpenMP,也没有使用GPU训练,因此只能跑一跑简单的数据集。MXnet的Image Classification Example中的样例都比较完整,使用步骤也很详细,训练最基本的MNIST数据集基本上不需要多余的工作量,只要能联网下载MNIST数据集(或者自己有数据集的话移动到对应文件夹下)就可以直接训练,效果也挺不错:

→ python train_mnist.py 
2016-05-23 08:51:41,616 Node[0] start with arguments Namespace(batch_size=128, data_dir='mnist/', gpus=None, kv_store='local', load_epoch=None, lr=0.1, lr_factor=1, lr_factor_epoch=1, model_prefix=None, network='mlp', num_epochs=10, num_examples=60000, save_model_prefix=None)
[08:51:45] src/io/iter_mnist.cc:91: MNISTIter: load 60000 images, shuffle=1, shape=(128,784)
[08:51:46] src/io/iter_mnist.cc:91: MNISTIter: load 10000 images, shuffle=1, shape=(128,784)
2016-05-23 08:51:46,460 Node[0] Start training with [cpu(0)]
...
2016-05-23 08:52:02,548 Node[0] Epoch[9] Batch [450]    Speed: 41054.59 samples/sec Train-top_k_accuracy_20=1.000000
2016-05-23 08:52:02,605 Node[0] Epoch[9] Resetting Data Iterator
2016-05-23 08:52:02,605 Node[0] Epoch[9] Time cost=1.470
2016-05-23 08:52:02,750 Node[0] Epoch[9] Validation-accuracy=0.977464
2016-05-23 08:52:02,750 Node[0] Epoch[9] Validation-top_k_accuracy_5=0.999299
2016-05-23 08:52:02,750 Node[0] Epoch[9] Validation-top_k_accuracy_10=1.000000
2016-05-23 08:52:02,750 Node[0] Epoch[9] Validation-top_k_accuracy_20=1.000000

默认的参数为:batch-size=128,初始学习率为0.1(固定学习率,lr_factor_epoch=1),使用最基本的多层感知机MLP进行训练。每个epoch耗时大约1.5秒左右,在10次迭代后测试集的accuracy达到0.977464。
其实在测试的时候看到Validation-accuracy时就有在想指的是cross-validation的accuracy还是test的accuracy,因此这时候就可以先去看看MXnet中到底是怎么读取数据、怎么使用KVstore的。

根据官方文档的介绍,MXnet使用iterator将参数传递给训练模型。这里的iterator会做一些数据预处理,并且生成指定大小的batch输入训练模型。
由于MNIST的数据比较简单,example里面提供了载入MNIST数据集的iterator实现,如下:

def get_iterator(data_shape):
    def get_iterator_impl(args, kv):
        data_dir = args.data_dir
        # 若指定位置没有MNIST数据集则会调用_download()函数联网下载
        if '://' not in args.data_dir:
            _download(args.data_dir)
        # data_shape变量为输入数据的格式。对于MNIST:
        # 若使用MLP进行训练,输入数据为有784个元素的一维向量,data_shape = (784, )
        # 若使用LeNet进行训练,输入数据为一个28*28的矩阵,data_shape = (1, 28, 28)
        # 因此若len(data_shape)不等于3时,设置flat变量为True,即对MNIST每一个输入数据一维扁平化
        flat = False if len(data_shape) == 3 else True

        # 训练集的参数指定
        train           = mx.io.MNISTIter(
            image       = data_dir + "train-images-idx3-ubyte",
            label       = data_dir + "train-labels-idx1-ubyte",
            input_shape = data_shape,
            batch_size  = args.batch_size,
            ## A commonly mistake is forgetting shuffle the image list during packing.
            ## This will lead fail of training.
            ## eg. accuracy keeps 0.001 for several rounds.
            shuffle     = True,
            flat        = flat,
            num_parts   = kv.num_workers,
            part_index  = kv.rank)

        # 测试集的参数指定
        val = mx.io.MNISTIter(
            image       = data_dir + "t10k-images-idx3-ubyte",
            label       = data_dir + "t10k-labels-idx1-ubyte",
            input_shape = data_shape,
            batch_size  = args.batch_size,
            flat        = flat,
            num_parts   = kv.num_workers,
            part_index  = kv.rank)

        return (train, val)
    return get_iterator_impl

train_mnist.py 的main函数里,会调用get_iterator()函数得到输入的iterator,传递给train_model.fit()函数执行真正的训练过程。
在之前的example介绍里有说到,Image Classification(包括后面基于CNN的很多其它网络)的不同网络结构运用在不同的数据集上,最后都是回到调用train_model.fit()函数进行训练。因此输入数据的获取和iterator的定义都在对应的 train_{mnist, cifar10, imagenet}.py 中,最简单的定义就如上面的代码所示。

Build your own iterator


MNIST输入数据的格式类型分为recordio,MNISTcsv。MNIST数据集的参数指定较为简单,上面的例子基本都覆盖到了。有关csv和MNIST数据集的更多参数指定信息<--点击链接。
对于图片数据集(recordio格式的数据),在创建iterator时,一般需要指定的参数有五类,包括:

  1. 数据集参数 (Dataset Param),提供了数据集的基本信息,如数据文件地址、数据形状(即前例中的input_shape)等等。
  2. 批参数 (Batch Param) 提供了形成batch的信息,比如batch size
  3. Augmentation Param 可以设定对数据集预处理的参数,比如mean_image(将图像中的每个像素减去图片像素均值),rand_crop(随机对图像进行部分切割),rand_mirror(随机对图像进行水平对称变换)等等。
  4. 后台参数 (Backend Param) 控制后台线程来隐藏读取数据的开销的相关参数,如preprocess_threads设定后台预读取线程数量,prefetch_buffer设定预读取buffer的大小。
  5. 辅助参数 (Auxiliary Param) 提供用于调试的参数设定,如verbose设定是否要输出parser信息。

具体的参数定义可以看官方文档:I/O API

Use your own data


要使用自己的数据集(或者ImageNet数据集),由于MXnet没有提供类似MNIST和cifar的自动下载和加载脚本将原始数据转换为ImageRecord数据,因此需要自己进行数据格式转换。
不过将数据转换为ImageRecord格式也很简单:

  • 首先将图像存储为压缩过的格式(比如.jpg),以降低数据量。
  • 使用MXnet提供的make_list[./mxnet/tools/make_list.py]工具生成lst文件,lst文件的格式为
integer_image_index \t label_index \t path_to_image

make_list接受的参数包括

  • chunks[int]:将原始数据集分成chunks块,得到chunks个数据量相同但对应数据不同的lst文件,默认值为1。
  • train_ratio[float]:指定每个chunk内用于训练的数据所占的比例,可以设置不同的训练集-测试集比,默认值为1(即所有数据用于训练)。
  • exts[list]:接受的输入数据格式,默认值为{.jpg,.jpeg}。
  • recursive[bool]:若设定为TRUE且原始数据集已经按照label放在了不同的子文件夹中,则make_list会自动为每个子文件夹内的数据标记对应的label_index,否则所有的数据都标注统一 label_index = 0,默认值为FALSE。
  • 使用MXnet提供的im2rec[./mxnet/tools/im2rec.{cc,py}]工具(提供C++版本和Python版本),通过原始数据和lst文件得到ImageRecord格式的数据供MXnet使用。若不指定lst文件则使用与make_list相同的方法先生成lst文件再生成ImageRecord文件。im2rec除了有make_list相同的参数外,在生成ImageRecord部分的参数还有
    • resize[int, default = 0]:等比例缩放图片,将图片短边设置为指定大小。
    • center_crop[bool, default = FALSE]:截取图片中间的方形部分,方形边长为短边长。
    • quality[int, default = 80]:设定图像的质量(.jpg:1-100, .png:1-9)。
    • num_thread[int, default = 1]:若使用多线程进行数据格式转换,则生成图像顺序会与输入list的不同。
    • color[int, default = 1, choice = {-1, 0, 1}]:输入图像的color mode,若为1则直接读如彩色数据,0为灰度模式,-1为使用alpha channel(<--这个应该是图像处理领域的专业知识,我也不是很理解)。
    • encoding[str, default = .jpg, choice = {.jpg, .png}]:图像转换后保存的格式。

这边会遇到一点问题,如果调用im2rec.py的时候提示

No module named cv

的话,网上查询到的原因是没安装openCV(不过其实之前装了……)
那只要把代码中

import cv, cv2

中的cv去掉即可,后续好像只使用到了cv2库中的内容,不需要cv。

然后就可以在MXnet中使用自己的数据集了。

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 159,015评论 4 362
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 67,262评论 1 292
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 108,727评论 0 243
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 43,986评论 0 205
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 52,363评论 3 287
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 40,610评论 1 219
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 31,871评论 2 312
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 30,582评论 0 198
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 34,297评论 1 242
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 30,551评论 2 246
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 32,053评论 1 260
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 28,385评论 2 253
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 33,035评论 3 236
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 26,079评论 0 8
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 26,841评论 0 195
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 35,648评论 2 274
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 35,550评论 2 270

推荐阅读更多精彩内容