PyTorch:数据加载和预处理

Github地址
简书地址
CSDN地址

此教程翻译自PyTorch官方教程

作者: Sasank Chilamkurthy

在解决任何机器学习问题上,在准备数据上会付出很大努力。PyTorch 提供了许多工具, 使数据加载变得简单,希望能使你的代码更具可读性。本教程中,我们将看到图和从一个不重要的数据集中加载和预处理/增强数据。

要运行本教程,请确保已安装一下软件包:

  1. scikit-image: 用于图像 IO 和 变换
  2. pandas: 更简单的 csv 解析
from __future__ import print_function, division
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

plt.ion()   # interactive mode

我们将要处理的数据集是面部姿势,意味着一张人脸将像下面这样被标注:


landmarked_face2

每张人脸总共有68个不同的地方被标注。

注意:
数据下载地址为 https://download.pytorch.org/tutorial/faces.zip, 图像位于名为“faces/“的目录中。这个数据集实际上是通过对来自 imagenet 的几张标注为 ‘face' 的图片应用优秀的 dlib 的姿态估计来生成的。

数据集带有一个 csv 标注文件,里面的标注内容看起来像下面这样:

image_name,part_0_x,part_0_y,part_1_x,part_1_y,part_2_x, ... ,part_67_x,part_67_y
0805personali01.jpg,27,83,27,98, ... 84,134
1084239450_e76e00b7e7.jpg,70,236,71,257, ... ,128,312

让我们快速读取 csv 文件,并把标记数据保存在一个(N, 2)的数组中,其中 N 是特征点的数量。

landmarks_frame = pd.read_csv("./data//faces/face_landmarks.csv")
n = 65
img_name = landmarks_frame.ix[n, 0]
landmarks = landmarks_frame.ix[n, 1:].as_matrix().astype('float')
landmarks = landmarks.reshape(-1, 2)

print("Image name: {}".format(img_name))
print("Landmarks shape: {}".format(landmarks.shape))
print("First 4 Landmarks: {}".format(landmarks[:4]))

输出:

Image name: person-7.jpg
Landmarks shape: (68, 2)
First 4 Landmarks: [[ 32.  65.]
 [ 33.  76.]
 [ 34.  86.]
 [ 34.  97.]]

让我们写一个简单的帮主函数来显示图像及其特征点,并用他来显示一个样本。

def show_landmarks(image, landmarks):
   """SHow image with landmarks"""
   plt.imshow(image)
   plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker=".", c="r")

plt.figure()
img = io.imread(os.path.join("./data/faces/", img_name))
show_landmarks(io.imread(os.path.join("./data/faces/", img_name)), landmarks)
plt.show()

输出:


sphx_glr_data_loading_tutorial_001

注意:要得到以上结果,请把 plt.ion() 注释掉。

Dataset 类

torch.utils.data.Dataset 是一个表示数据集的抽象类。你自定义的数据集类应该继承自 Dataset 并重写如下方法:

  • __len__: 返回数据集的大小, len(dataset)
  • __getitem__: 是数据集支持索引操作, dataset[i]

让我们维我们的人脸特征点数据集创建一个数据集类。我们将在 __init__ 中读取 csv, 但是让读取图片的操作在 __getitem__ 中进行。这是内存高效的,因为所有的图像不是一次存储在内存中,而是根据需要进行读取。

我们数据集的样本将是一个字典{'image': image, 'landmarks': landmarks}。我们的数据集将接受一个可选参数transform’ 以便可以对样本应用任何需要的处理。我们将在下一节看到transform` 的好处。

class FaceLandmarksDataset(Dataset):
    """Face Landmarks dataset."""

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.landmarks_frame.ix[idx, 0])
        image = io.imread(img_name)
        landmarks = self.landmarks_frame.ix[idx, 1:].as_matrix().astype('float')
        landmarks = landmarks.reshape(-1, 2)
        sample = {'image': image, 'landmarks': landmarks}

        if self.transform:
            sample = self.transform(sample)

        return sample

让我们初始化这个类的实例,并在数据样本上迭代。我们讲打印开始4个样本的大小并显示他们的特征点。

face_dataset = FaceLandmarksDataset(csv_file='faces/face_landmarks.csv',
                                    root_dir='faces/')

fig = plt.figure()

for i in range(len(face_dataset)):
    sample = face_dataset[i]

    print(i, sample['image'].shape, sample['landmarks'].shape)

    ax = plt.subplot(1, 4, i + 1)
    plt.tight_layout()
    ax.set_title('Sample #{}'.format(i))
    ax.axis('off')
    show_landmarks(**sample)

    if i == 3:
        plt.show()
        break
sphx_glr_data_loading_tutorial_002

输出:

0 (324, 215, 3) (68, 2)
1 (500, 333, 3) (68, 2)
2 (250, 258, 3) (68, 2)
3 (434, 290, 3) (68, 2)

Transform (变换)

从上面的例子我们可以看到一个问题:样本的尺寸不一样。大部分的神经网络希望一个固定大小的图像。因此,我们需要写一些预处理代码。让我们来创建三种变换:

  • Rescale: 缩放图像
  • RandomCrop: 随机剪裁图像,这是一种数据增强的方法
  • ToTensor: 把 numpy 图像转换为 PyTorch 图像(我们需要交换轴)

我们将把它们写成一个可调用的类而不是函数,所以变换所需的参数不必在每次调用时都传递。为此,我们只需实现 __call__ 方法,如果需要可以实现 __init__ 方法。我们可以向下面这样使用他们:

tsfm = Transform(params)
transformed_sample = tsfm(sample)

请观察下面的变换是如何应用在图像和特征点上的。

class Rescale(object):
    """Rescale the image in a sample to a given size.

    Args:
        output_size (tuple or int): Desired output size. If tuple, output is
            matched to output_size. If int, smaller of image edges is matched
            to output_size keeping aspect ratio the same.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)

        img = transform.resize(image, (new_h, new_w))

        # h and w are swapped for landmarks because for images,
        # x and y axes are axis 1 and 0 respectively
        landmarks = landmarks * [new_w / w, new_h / h]

        return {'image': img, 'landmarks': landmarks}


class RandomCrop(object):
    """Crop randomly the image in a sample.

    Args:
        output_size (tuple or int): Desired output size. If int, square crop
            is made.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        new_h, new_w = self.output_size

        top = np.random.randint(0, h - new_h)
        left = np.random.randint(0, w - new_w)

        image = image[top: top + new_h,
                      left: left + new_w]

        landmarks = landmarks - [left, top]

        return {'image': image, 'landmarks': landmarks}


class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        # swap color axis because
        # numpy image: H x W x C
        # torch image: C X H X W
        image = image.transpose((2, 0, 1))
        return {'image': torch.from_numpy(image),
                'landmarks': torch.from_numpy(landmarks)}

组合变换

现在,我们应用这些变换到我们的样本上。

假如我们想先把图像的较短的一边缩放到256,然后从中随机剪裁一个224*224大小的图像。即我们想要组合 RescaleRandomCrop 两个变换。

torchvision.transforms.Compose 是一个简单的可调用类,允许我们来组合多个变换

scale = Rescale(256)
crop = RandomCrop(128)
composed = transforms.Compose([Rescale(256),
                               RandomCrop(224)])

# Apply each of the above transforms on sample.
fig = plt.figure()
sample = face_dataset[65]
for i, tsfrm in enumerate([scale, crop, composed]):
    transformed_sample = tsfrm(sample)

    ax = plt.subplot(1, 3, i + 1)
    plt.tight_layout()
    ax.set_title(type(tsfrm).__name__)
    show_landmarks(**transformed_sample)

plt.show()
sphx_glr_data_loading_tutorial_003

迭代数据集

我们把这些放在一个来创建一个包含组合变换的数据集。总之,每当这个数据集被采样时执行一下操作:

  • 即时从文件中读取图像。
  • 对图像应用变换。
  • 由于其中一个变换是随机的,因此数据的采样得到增强。

我们可以使用和之前一样的 for i in range 循环来迭代创建的数据集。

transformed_dataset = FaceLandmarksDataset(csv_file='faces/face_landmarks.csv',
                                           root_dir='faces/',
                                           transform=transforms.Compose([
                                               Rescale(256),
                                               RandomCrop(224),
                                               ToTensor()
                                           ]))

for i in range(len(transformed_dataset)):
    sample = transformed_dataset[i]

    print(i, sample['image'].size(), sample['landmarks'].size())

    if i == 3:
        break

输出:

0 torch.Size([3, 224, 224]) torch.Size([68, 2])
1 torch.Size([3, 224, 224]) torch.Size([68, 2])
2 torch.Size([3, 224, 224]) torch.Size([68, 2])
3 torch.Size([3, 224, 224]) torch.Size([68, 2])

但是,通过使用简单的for循环遍历数据,我们将失去许多功能。特别是我们错过了:

  • 批处理数据
  • 打乱数据
  • 使用多线程并行加载数据

torch.utils.data.DataLoader 是一个提供以上所有的功能的迭代器。下面使用的参数应该是清楚的。其中一个又去的参数是 collate_fn。你可以指定如何使用 collate_fn 对样本进行批处理。但是,对大多数情况来说,默认的自动分页应该可以正常工作的很好。

dataloader = DataLoader(transformed_dataset, batch_size=4,
                        shuffle=True, num_workers=4)


# Helper function to show a batch
def show_landmarks_batch(sample_batched):
    """Show image with landmarks for a batch of samples."""
    images_batch, landmarks_batch = \
            sample_batched['image'], sample_batched['landmarks']
    batch_size = len(images_batch)
    im_size = images_batch.size(2)

    grid = utils.make_grid(images_batch)
    plt.imshow(grid.numpy().transpose((1, 2, 0)))

    for i in range(batch_size):
        plt.scatter(landmarks_batch[i, :, 0].numpy() + i * im_size,
                    landmarks_batch[i, :, 1].numpy(),
                    s=10, marker='.', c='r')

        plt.title('Batch from dataloader')

for i_batch, sample_batched in enumerate(dataloader):
    print(i_batch, sample_batched['image'].size(),
          sample_batched['landmarks'].size())

    # observe 4th batch and stop.
    if i_batch == 3:
        plt.figure()
        show_landmarks_batch(sample_batched)
        plt.axis('off')
        plt.ioff()
        plt.show()
        break
sphx_glr_data_loading_tutorial_004

输出:

0 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
1 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
2 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
3 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])

后记:torchvision

在此教程中,我们看到了如何写和使用数据集(dataset),变换(transform)和数据加载器(dataloader)。
torchvision 包提供了一些常见数据集和变换。你甚至可能不需要编写自定义类。torchvision 提供了一个更通过的数据集: ImageFolder。它假设图像按以下方式组织:

root/ants/xxx.png
root/ants/xxy.jpeg
root/ants/xxz.png
.
.
.
root/bees/123.jpg
root/bees/nsdf3.png
root/bees/asd932_.png

其中 'ant','bees'等是类别标签,类似的通用的操作于 PIL.Image 的变换,如RandomHOrizontalFlipScale 也是可以获取的。你可以向下面一样使用这些变换来写一个数据加载器。

import torch
from torchvision import transforms, datasets

data_transform = transforms.Compose([
        transforms.RandomSizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train',
                                           transform=data_transform)
dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,
                                             batch_size=4, shuffle=True,
                                             num_workers=4)

有关训练代码的示例,请阅读 迁移学习章节。

Python 源码
Jupyter

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 158,425评论 4 361
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 67,058评论 1 291
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 108,186评论 0 243
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 43,848评论 0 204
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 52,249评论 3 286
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 40,554评论 1 216
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 31,830评论 2 312
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 30,536评论 0 197
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 34,239评论 1 241
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 30,505评论 2 244
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 32,004评论 1 258
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 28,346评论 2 253
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 32,999评论 3 235
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 26,060评论 0 8
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 26,821评论 0 194
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 35,574评论 2 271
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 35,480评论 2 267

推荐阅读更多精彩内容