易 AI - 使用 TensorFlow 和 Labelme 训练自定义 U-NET 图像分割模型

原文:https://makeoptim.com/deep-learning/yiai-unet

前言

先前笔者在使用 TensorFlow Object Detection API Mask R-CNN 训练自定义图像分割模型介绍了如何使用 Mask R-CNN 实现图像分割。不过 Mask R-CNN 网络较为复杂,性能消耗较高,在实际的运用中,如果不是很复杂的分割任务,还有一个较为合适的选择,那就是本文要讲解的 U-NET

介绍 U-NET 的文章很多,不过从自定义数据集模型定义训练预测的文章却寥寥无几。因此,本文旨在通过 一个 Demo 来覆盖各个步骤,让大家快速掌握 U-NET

环境搭建

下载源代码 https://github.com/CatchZeng/tensorflow-unet-labelme 到本地,并进入该目录下。

如果你使用的是 macOS, 你需要在安装前先执行以下命令。

❯ brew install pyqt

执行以下命令,安装虚拟环境 unet

❯ conda create -n unet -y python=3.9 && conda activate unet && pip install -r requirements.txt

数据集

文本还是以使用 TensorFlow Object Detection API Mask R-CNN 训练自定义图像分割模型 中的茶杯(cup)、茶壶(teapot)、加湿器(humidifier) 来做案例

数据标注

数据标注已在使用 TensorFlow Object Detection API Mask R-CNN 训练自定义图像分割模型 阐述,这里就不再赘述。

生成 VOC 数据集

U-NET 的数据集包含原始图像(jpg) 和标签(mask)图像(png),通常使用 VOC 格式来整理数据集。

将标注好的数据存放到 tensorflow-unet-labelmedatasets/train 下,并新建 datasets/labels.txt,内容为分类名称,详见 https://github.com/wkentaro/labelme/tree/main/examples/semantic_segmentation

❯ tree -L 3
.
├── Makefile
├── README.md
├── datasets
│   ├── README.md
│   ├── labels.txt
│   ├── train
│   │   ├── 1.jpg
│   │   ├── 1.json
......
│   │   ├── 9.jpg
│   │   └── 9.json
├── labelme2voc.py
├── train.gif
├── unet.ipynb
└── voc_annotation.py

注:可以参考 https://github.com/CatchZeng/tensorflow-unet-labelme/tree/master/datasets

执行以下命令,生成 VOC 格式数据集。

❯ make voc
❯ tree -L 5
.
├── Makefile
├── README.md
├── datasets
│   ├── README.md
│   ├── labels.txt
│   ├── test
│   │   └── 1.jpg
│   ├── train
│   │   ├── 1.jpg
│   │   ├── 1.json
......
│   │   ├── 9.jpg
│   │   └── 9.json
│   └── train_voc
│       ├── ImageSets
│       │   └── Segmentation
│       │       ├── test.txt
│       │       ├── train.txt   # 训练集图像名称列表
│       │       ├── trainval.txt
│       │       └── val.txt  # 验证集图像名称列表
│       ├── JPEGImages # 原图
│       │   ├── 1.jpg
......
│       │   └── 9.jpg
│       ├── SegmentationClass
│       │   ├── 1.npy
......
│       │   └── 9.npy
│       ├── SegmentationClassPNG # 标签(mask)图
│       │   ├── 1.png
......
│       │   └── 9.png
│       ├── SegmentationClassVisualization
│       │   ├── 1.jpg
......
│       │   └── 9.jpg
│       └── class_names.txt
├── labelme2voc.py
├── train.gif
├── unet.ipynb
└── voc_annotation.py

生成的 datasets/train_voc 便是训练用到的数据集。

训练

打开 unet.ipynb,并选择 Python 解释器为 unet,即可开始训练。

代码详解

数据集分为训练集验证集,分别从 ImageSets/Segmentation/train.txtImageSets/Segmentation/val.txt 读取文件。然后,通过 UnetDataset 类构建成为 tf.keras.utils.Sequence 对象,方便后面通过 model.fit 直接训练。

dataset_path = 'datasets/train_voc'

# read dataset txt files
with open(os.path.join(dataset_path, "ImageSets/Segmentation/train.txt"),
          "r",
          encoding="utf8") as f:
    train_lines = f.readlines()

with open(os.path.join(dataset_path, "ImageSets/Segmentation/val.txt"),
          "r",
          encoding="utf8") as f:
    val_lines = f.readlines()

train_batches = UnetDataset(train_lines, INPUT_SHAPE, BATCH_SIZE, NUM_CLASSES,
                            True, dataset_path)
val_batches = UnetDataset(val_lines, INPUT_SHAPE, BATCH_SIZE, NUM_CLASSES,
                          False, dataset_path)

STEPS_PER_EPOCH = len(train_lines) // BATCH_SIZE
VALIDATION_STEPS = len(val_lines) // BATCH_SIZE // VAL_SUBSPLITS

UnetDataset 类继承自 tf.keras.utils.Sequence 。通过 __getitem__ 方法返回一组 batch_size 的数据,其中包含原图(images)和标签图(targets)。因为模型有固定的 input shape,因此,在 process_data 方法中做了 resize 操作;在训练过程中,还可以加入数据增强,这里使用了一个简单的 flip

class UnetDataset(tf.keras.utils.Sequence):

    def __init__(self, annotation_lines, input_shape, batch_size, num_classes,
                 train, dataset_path):
        self.annotation_lines = annotation_lines
        self.length = len(self.annotation_lines)
        self.input_shape = input_shape
        self.batch_size = batch_size
        self.num_classes = num_classes
        self.train = train
        self.dataset_path = dataset_path

    def __len__(self):
        return math.ceil(len(self.annotation_lines) / float(self.batch_size))

    def __getitem__(self, index):
        images = []
        targets = []
        for i in range(index * self.batch_size, (index + 1) * self.batch_size):
            i = i % self.length
            name = self.annotation_lines[i].split()[0]
            jpg = Image.open(
                os.path.join(os.path.join(self.dataset_path, "JPEGImages"),
                             name + ".jpg"))
            png = Image.open(
                os.path.join(
                    os.path.join(self.dataset_path, "SegmentationClassPNG"),
                    name + ".png"))

            jpg, png = self.process_data(jpg,
                                         png,
                                         self.input_shape,
                                         random=self.train)

            images.append(jpg)
            targets.append(png)

        images = np.array(images)
        targets = np.array(targets)
        return images, targets

    def rand(self, a=0, b=1):
        return np.random.rand() * (b - a) + a

    def process_data(self, image, label, input_shape, random=True):
        image = cvtColor(image)
        label = Image.fromarray(np.array(label))
        h, w, _ = input_shape

        # resize
        image, _, _ = resize_image(image, (w, h))
        label, _, _ = resize_label(label, (w, h))

        if random:
            # flip
            flip = self.rand() < .5
            if flip:
                image = image.transpose(Image.FLIP_LEFT_RIGHT)
                label = label.transpose(Image.FLIP_LEFT_RIGHT)

        # np
        image = np.array(image, np.float32)
        image = normalize(image)

        label = np.array(label)
        label[label >= self.num_classes] = self.num_classes

        return image, label

模型定义部分,比较简单,跟论文中一样,主要为下采样,上采样,和 concat

这里,笔者参考了 https://www.tensorflow.org/tutorials/images/segmentation ,详细的解析大家可以查看。

def unet_model(output_channels: int):
    inputs = tf.keras.layers.Input(shape=INPUT_SHAPE)

    # Downsampling through the model
    skips = down_stack(inputs)
    x = skips[-1]
    skips = reversed(skips[:-1])

    # Upsampling and establishing the skip connections
    for up, skip in zip(up_stack, skips):
        x = up(x)
        concat = tf.keras.layers.Concatenate()
        x = concat([x, skip])

    # This is the last layer of the model
    last = tf.keras.layers.Conv2DTranspose(filters=output_channels,
                                           kernel_size=3,
                                           strides=2,
                                           padding='same')  #64x64 -> 128x128

    x = last(x)

    return tf.keras.Model(inputs=inputs, outputs=x)

Callback 部分,笔者主要使用了 DisplayCallbackModelCheckpointCallback

DisplayCallback 用于训练完一个 epoch 后,显示预测的结果,便于观测模型的效果

ModelCheckpointCallback 用于训练完一个 epoch 后,在 logs 文件夹下 保存权值(模型),并且记录每个 epoch 后的准确率和损失率。这样,用户在训练完后,可以挑选效果比较好的模型

class DisplayCallback(tf.keras.callbacks.Callback):

    def on_epoch_end(self, epoch, logs=None):
        clear_output(wait=True)
        show_predictions()
        print('\nSample Prediction after epoch {}\n'.format(epoch + 1))


class ModelCheckpointCallback(tf.keras.callbacks.Callback):

    def __init__(self,
                 filepath,
                 monitor='val_loss',
                 verbose=0,
                 save_best_only=False,
                 save_weights_only=False,
                 mode='auto',
                 period=1):
        super(ModelCheckpointCallback, self).__init__()
        self.monitor = monitor
        self.verbose = verbose
        self.filepath = filepath
        self.save_best_only = save_best_only
        self.save_weights_only = save_weights_only
        self.period = period
        self.epochs_since_last_save = 0

        if mode not in ['auto', 'min', 'max']:
            warnings.warn(
                'ModelCheckpoint mode %s is unknown, '
                'fallback to auto mode.' % (mode), RuntimeWarning)
            mode = 'auto'

        if mode == 'min':
            self.monitor_op = np.less
            self.best = np.Inf
        elif mode == 'max':
            self.monitor_op = np.greater
            self.best = -np.Inf
        else:
            if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
                self.monitor_op = np.greater
                self.best = -np.Inf
            else:
                self.monitor_op = np.less
                self.best = np.Inf

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        self.epochs_since_last_save += 1
        if self.epochs_since_last_save >= self.period:
            self.epochs_since_last_save = 0
            filepath = self.filepath.format(epoch=epoch + 1, **logs)
            if self.save_best_only:
                current = logs.get(self.monitor)
                if current is None:
                    warnings.warn(
                        'Can save best model only with %s available, '
                        'skipping.' % (self.monitor), RuntimeWarning)
                else:
                    if self.monitor_op(current, self.best):
                        if self.verbose > 0:
                            print(
                                '\nEpoch %05d: %s improved from %0.5f to %0.5f,'
                                ' saving model to %s' %
                                (epoch + 1, self.monitor, self.best, current,
                                 filepath))
                        self.best = current
                        if self.save_weights_only:
                            self.model.save_weights(filepath, overwrite=True)
                        else:
                            self.model.save(filepath, overwrite=True)
                    else:
                        if self.verbose > 0:
                            print('\nEpoch %05d: %s did not improve' %
                                  (epoch + 1, self.monitor))
            else:
                if self.verbose > 0:
                    print('\nEpoch %05d: saving model to %s' %
                          (epoch + 1, filepath))
                if self.save_weights_only:
                    self.model.save_weights(filepath, overwrite=True)
                else:
                    self.model.save(filepath, overwrite=True)

预测模型部分,先将文件 resize 到模型的 INPUT_SHAPE 大小。这里需要注意的是,预测的图不一定是与 INPUT_SHAPE 比例相等的。为了不因为比例问题,导致预测结果不准确,笔者这里在 resize 的时候为图片不在比例的地方,添加了灰色占位边,如下图所示。

然后再预测完之后,再去掉灰边

注:本案例效果图因为是 Demo,所以只是训练了 20 几个 epoch,准确度是一般的,大家在实际应用中,可以多训练下,提高准确度。

def detect_image(image_path):
    image = Image.open(image_path)
    image = cvtColor(image)

    old_img = copy.deepcopy(image)
    ori_h = np.array(image).shape[0]
    ori_w = np.array(image).shape[1]

   # resize 并添加灰边
    image_data, nw, nh = resize_image(image, (INPUT_SHAPE[1], INPUT_SHAPE[0]))

    image_data = normalize(np.array(image_data, np.float32))

    image_data = np.expand_dims(image_data, 0)

    pr = model.predict(image_data)[0]

   ## 去掉灰边
    pr = pr[int((INPUT_SHAPE[0] - nh) // 2) : int((INPUT_SHAPE[0] - nh) // 2 + nh), \
            int((INPUT_SHAPE[1] - nw) // 2) : int((INPUT_SHAPE[1] - nw) // 2 + nw)]

    pr = cv2.resize(pr, (ori_w, ori_h), interpolation=cv2.INTER_LINEAR)

    pr = pr.argmax(axis=-1)

    # seg_img = np.zeros((np.shape(pr)[0], np.shape(pr)[1], 3))
    # for c in range(NUM_CLASSES):
    #     seg_img[:, :, 0] += ((pr[:, :] == c ) * colors[c][0]).astype('uint8')
    #     seg_img[:, :, 1] += ((pr[:, :] == c ) * colors[c][1]).astype('uint8')
    #     seg_img[:, :, 2] += ((pr[:, :] == c ) * colors[c][2]).astype('uint8')
    seg_img = np.reshape(
        np.array(colors, np.uint8)[np.reshape(pr, [-1])], [ori_h, ori_w, -1])

    image = Image.fromarray(seg_img)
    image = Image.blend(old_img, image, 0.7)

    return image

小结

本文,通过一个 Demo 介绍了,如何从制作数据集到训练 U-NET 模型并预测图片整个流程。大家可以自己找一个场景,制作一个自定义的数据集,然后实践一遍,以便更好地掌握。本文就到这里了,咱们下一篇见。

延伸阅读

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 159,015评论 4 362
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 67,262评论 1 292
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 108,727评论 0 243
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 43,986评论 0 205
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 52,363评论 3 287
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 40,610评论 1 219
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 31,871评论 2 312
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 30,582评论 0 198
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 34,297评论 1 242
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 30,551评论 2 246
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 32,053评论 1 260
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 28,385评论 2 253
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 33,035评论 3 236
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 26,079评论 0 8
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 26,841评论 0 195
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 35,648评论 2 274
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 35,550评论 2 270

推荐阅读更多精彩内容