Pytorch 数据加载器: Dataset 和 DataLoader

为什么要用?

习惯于自己实现业务逻辑的每一步,以至于没有意识去寻找框架本身自有的数据预处理方法,Pytorch的Dataset 和 DataLoader便于加载和迭代处理数据,并且可以傻瓜式实现各种常见的数据预处理,以供训练使用。

调包侠

from torch.utils.data.dataset import Dataset, DataLoader
from torchvision import transforms  ##可方便指定各种transformer,直接传入DataLoader

Dataset 和 DataLoader是什么?

Dataset是一个包装类,可对数据进行张量(tensor)的封装,其可作为DataLoader的参数传入,进一步实现基于tensor的数据预处理。

如何处理自己的数据集?

很多教程里分两种情况:数据同在一个文件夹;数据按类别分布在不同文件夹。其实刚开始我是一头雾水,后来总结后发现,两种情况均可用一种方法来处理,即:只要有一份文件,记录图像数据路径及对应的标签即可,如下所示:

record.txt 示例:
    pic_path                          label
./pic_01/aaa.bmp                        1
./pic_22/bbb.bmp                        0
./pic_03/ccc.bmp                        3
./pic_01/ddd.bmp                        1
            ...

其实有了上面的一份数据对照表文件,即可不用管是否在同一文件夹或是不同文件夹的情况,我自己感觉是要方便一些。下面就按照这种方法来介绍如何使用。

第一步:实现MyDataset类

既然是要处理自己的数据集,那么一般情况下还是写一个自己的Dataset类,该类要继承Dataset,并重写 __ init __() 和 __ getitem __() 两个方法。

例如:
class MyDataset(Dataset):
    def __init__(self, record_path, is_train=True):
        ## record_path:记录图片路径及对应label的文件
        self.data = []
        self.is_train = is_train
        with open(record_path) as fp:
            for line in fp.readlines():
                if line == '\n':
                    break
                else:
                    tmp = line.split("\t")
                    ## tmp[0]:某图片的路径,tmp[1]:该图片对应的label
                    self.data.append([tmp[0], tmp[1]])
        # 定义transform,将数据封装为Tensor
        self.transformations = transforms.Compose([transforms.ToTensor()])

    # 获取单条数据
    def __getitem__(self, index):
        img = self.transformations (Image.open(self.data[index][0]).resize((256,256)).convert('RGB'))
        label = int(self.data[index][1])
        return img, label

    # 数据集长度
    def __len__(self):
        return len(self.data)

上面是一个简单的MyDataset类,仅依赖记录了图像位置以及相应label的record文件,实现对数据集的读取和Tensor的转换

当然,根据个人对数据预处理的需求不同,该类的实现可进一步完善,例如:

class MyDataset(Dataset):
    def __init__(self, base_path, is_train=True):
        self.data = []
        self.is_train = is_train
        with open(base_path) as fp:
            for line in fp.readlines():
                if line == '\n':
                    break
                else:
                    tmp = line.split("\t")
                    self.data.append([tmp[0], tmp[1]])
        ## transforms.Normalize:对R G B三通道数据做均值方差归一化,因此给出下方三个均值和方差
        normMean = [0.49139968, 0.48215827, 0.44653124]
        normStd = [0.24703233, 0.24348505, 0.26158768]
        normTransform = transforms.Normalize(normMean, normStd)
        ## 可由 transforms.Compose([transformer_01, transformer_02, ...])实现一些数据的处理和增强
        self.trainTransform = transforms.Compose([       ## train训练集处理
            transforms.RandomCrop(32, padding=4),        ## 图像裁剪的transforms
            transforms.RandomHorizontalFlip(p=0.5),      ## 以50%概率水平翻转
            transforms.ToTensor(),                       ## 转为Tensor形式
            normTransform                                ## 进行 R G B数据归一化
        ])
        ## 测试集的transforms数据处理
        self.testTransform = transforms.Compose([  
            transforms.ToTensor(),
            normTransform
        ])

    # 获取单条数据
    def __getitem__(self, index):
        img = self.trainTransform(Image.open(self.data[index][0]).resize((256,256)).convert('RGB'))
        if not self.is_train:
            img = self.testTransform(Image.open(self.data[index][0]).resize((256, 256)).convert('RGB'))
        label = int(self.data[index][1])
        return img, label

    # 数据集长度
    def __len__(self):
        return len(self.data)

或许已经看出来了,所有可能的数据处理或数据增强操作,都可通过transforms来进行调用与封装,是不是一下变得很方便呢!

第二步:将MyDataset装入DataLoader中

MyDataset类中的init方法要求传入记录数据路径及label的文件,因此可如下所示进行操作:

import MyDataset
train_data = MyDataset.MyDataset("./train_record.txt")
test_data = myDataset.MyDataset("./test_record.txt")
kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
trainLoader = DataLoader(dataset=train_data,batch_size=64,shuffle=True,**kwargs)
testLoader = DataLoader(dataset=test_data,batch_size=64,shuffle=False, **kwargs)

这样,便生成了trainLoader 和testLoader

第三步:在训练中使用DataLoader
for epoch in range(1, args.nEpochs + 1):
     ## 定义好的train方法
     train(args, epoch, model, trainLoader, optimizer)
     ## 定义好的val方法,用于测试或验证
     val(args, epoch, model, testLoader, optimizer)

最后

以上便是使用 Dataset和DataLoader处理自己数据集的通用方法,当然本次仅记录了图片数据的使用方法,后续记录文本数据处理方法。

彩蛋

ooh~~ 那么对于Pytorch自带数据集如果处理呢?
若直接使用 CIFAR10 数据集,可以如下处理:

import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

normMean = [0.49139968, 0.48215827, 0.44653124]
normStd = [0.24703233, 0.24348505, 0.26158768]
normTransform = transforms.Normalize(normMean, normStd)
trainTransform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normTransform
    ])
testTransform = transforms.Compose([
        transforms.ToTensor(),
        normTransform
    ])

kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
trainLoader = DataLoader(dset.CIFAR10(root='cifar', train=True, download=True,
                     transform=trainTransform),batch_size=64, shuffle=True, **kwargs)
testLoader = DataLoader(dset.CIFAR10(root='cifar', train=False, download=True,
                     transform=testTransform),batch_size=64, shuffle=False, **kwargs)

其实也就是 torchvision.datasets将这些共用数据集本身就做了 Dataset类的封装,因此直接调用,传入你想要的transforms,再丢给DataLoader即可。

转载注明出处:https://www.jianshu.com/p/b558c538eac2

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 156,265评论 4 359
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 66,274评论 1 288
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 106,087评论 0 237
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 43,479评论 0 203
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 51,782评论 3 285
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 40,218评论 1 207
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 31,594评论 2 309
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 30,316评论 0 194
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 33,955评论 1 237
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 30,274评论 2 240
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 31,803评论 1 255
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 28,177评论 2 250
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 32,732评论 3 229
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 25,953评论 0 8
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 26,687评论 0 192
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 35,263评论 2 267
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 35,189评论 2 258