TensorFlow 2: Data Preparation and Training on a Single Machine with Multiple GPUs

Preface

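If you are reading this post, it is wealth that brought us together.
These days even a single GPU is hard to come by, let alone several for multi-GPU training...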

A few points to watch out for

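  1. Build the model under tf.distribute.MirroredStrategy.
  2. To split each batch evenly across the GPUs, feed the data through tf.data.Dataset.from_generator so that it is loaded from an iterator, and explicitly turn AutoShardPolicy off. If you skip this step, memory allocation can go wrong: GPU memory may overflow, and the validation loss computed during training may also be incorrect.
  3. Even with the steps above in place, TensorFlow 2 has a bug in how metrics are computed during training. To avoid triggering it, do the following. This is a real pain point and very hard to find unless you trace it carefully!
    Set metrics explicitly to binary_accuracy or sparse_categorical_accuracy.
    Do not simply set it to acc.
    Otherwise you will later get the error: as_list() is not defined on an unknown TensorShape
  4. The training data is produced on the fly by a generator not only to avoid loading everything at once and exhausting memory, but also because the data needs augmentation and transformation in real time, which makes the model more robust.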

Code

Building and compiling the model

Focus on how tf.distribute.MirroredStrategy is used; choose the loss function and optimizer you prefer. But metrics must never be set to acc!

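import tensorflow as tf
from tensorflow.keras.optimizers import Adam

# table_line, weighted_bce_dice_loss_v1_7_3 and the is_*/resnest_* settings
# come from the author's own project and must be defined before this point.
gpus = tf.config.list_physical_devices('GPU')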
batchsize = 8
print('apply: Adam + weighted_bce_dice_loss_v1_7_3')
if len(gpus) > 1:
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(device=gpu, enable=True)
    batchsize *= len(gpus)
    mirrored_strategy = tf.distribute.MirroredStrategy()
    with mirrored_strategy.scope():
        model = table_line.get_model(input_shape=(512, 512, 3),
                                     is_resnest_unet=is_resnest_unet,
                                     is_swin_unet=is_swin_unet,
                                     resnest_pretrain_model=resnest_pretrain_model)
        # apply custom loss
        model.compile(
            optimizer=Adam(
                learning_rate=0.0001),
            loss=weighted_bce_dice_loss_v1_7_3,
            metrics=['binary_accuracy'])
else:
    model = table_line.get_model(input_shape=(512, 512, 3),
                                 is_resnest_unet=is_resnest_unet,
                                 is_swin_unet=is_swin_unet,
                                 resnest_pretrain_model=resnest_pretrain_model)
    model.compile(
        optimizer=Adam(
            learning_rate=0.0001),
        loss=weighted_bce_dice_loss_v1_7_3,
        metrics=['binary_accuracy'])
print('batch size: {0}, GPUs: {1}'.format(batchsize, gpus))
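
The batch-size scaling in the block above can be checked in isolation. The following is a minimal sketch in plain Python with no TensorFlow dependency. The helper names global_batch_size and per_replica_batch_size are hypothetical, and the even split per replica is an assumption about how the global batch is divided, not something the original code verifies.

```python
def global_batch_size(per_gpu_batch: int, num_gpus: int) -> int:
    """Scale the per-GPU batch size by the GPU count (treating 0 GPUs as one device)."""
    return per_gpu_batch * max(1, num_gpus)


def per_replica_batch_size(global_batch: int, num_replicas: int) -> int:
    """Samples each replica receives, assuming the global batch divides evenly."""
    if num_replicas <= 0:
        raise ValueError("num_replicas must be positive")
    if global_batch % num_replicas != 0:
        raise ValueError("global batch does not divide evenly across replicas")
    return global_batch // num_replicas


if __name__ == "__main__":
    total = global_batch_size(8, 4)
    print(total, per_replica_batch_size(total, 4))
```

Keeping the per-GPU size fixed and multiplying by the device count means each GPU sees the same amount of data per step as it would in single-GPU training.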

Building the dataset iterator

def makeDataset(generator_func,
                data_list,
                line_path,
                batchsize,
                draw_line,
                is_raw,
                need_rotate,
                only_flip,
                is_wide_line,
                strategy=None):
    # Make a dataset from the generator. MAKE SURE TO SPECIFY THE DATA TYPE!!!
    ds = tf.data.Dataset.from_generator(generator_func,
                                        args=[data_list, line_path, batchsize,
                                              draw_line, is_raw, need_rotate,
                                              only_flip, is_wide_line],
                                        output_types=(tf.float64, tf.float64))
    # Explicitly disable auto-sharding of the input dataset
    options = tf.data.Options()
    options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
    ds = ds.with_options(options)

    # Optional: Make it a distributed dataset if you're using a strategy
    if strategy is not None:
        ds = strategy.experimental_distribute_dataset(ds)

    return ds
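
In recent TensorFlow 2 releases, output_types is deprecated in favor of output_signature, which declares the shape as well as the dtype of each element. The sketch below shows that variant on a toy generator; toy_gen, make_toy_dataset and the 64x64 sizes are hypothetical stand-ins, not the author's code. Declaring static shapes might also help with shape-related errors such as the as_list() one mentioned earlier, but that is an untested assumption.

```python
import numpy as np
import tensorflow as tf


def toy_gen(batchsize):
    """Yield random (image, mask) batches forever; a stand-in for a real loader."""
    n = int(batchsize)  # args arrive as NumPy values, so cast explicitly
    while True:
        x = np.random.rand(n, 64, 64, 3)
        y = np.random.rand(n, 64, 64, 2)
        yield x, y


def make_toy_dataset(batchsize):
    # Declare dtype and static shape; only the batch dimension is left as None.
    signature = (
        tf.TensorSpec(shape=(None, 64, 64, 3), dtype=tf.float64),
        tf.TensorSpec(shape=(None, 64, 64, 2), dtype=tf.float64),
    )
    ds = tf.data.Dataset.from_generator(
        toy_gen, args=[batchsize], output_signature=signature)
    # Same option as in makeDataset: turn auto-sharding off.
    options = tf.data.Options()
    options.experimental_distribute.auto_shard_policy = (
        tf.data.experimental.AutoShardPolicy.OFF)
    return ds.with_options(options)


if __name__ == "__main__":
    ds = make_toy_dataset(2)
    print(ds.element_spec)
```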

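Creating the iterators for the training and validation data
Here gen is the function that generates the data. All the other arguments, except the final strategy argument, are the arguments that function requires.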

training_ds = makeDataset(gen,
                          data_list=trainP,
                          line_path=line_path,
                          batchsize=batchsize,
                          draw_line=False,
                          is_raw=is_raw,
                          need_rotate=need_rotate,
                          only_flip=only_flip,
                          is_wide_line=is_wide_line,
                          strategy=None)
validation_ds = makeDataset(gen,
                            data_list=testP,
                            line_path=line_path,
                            batchsize=batchsize,
                            draw_line=False,
                            is_raw=is_raw,
                            need_rotate=need_rotate,
                            only_flip=only_flip,
                            is_wide_line=is_wide_line,
                            strategy=None)

An example of the data-generating function. Anyone who has worked with iterators will recognize the pattern.

def gen(paths,
        line_path,
        batchsize=2,
        draw_line=True,
        is_raw=False,
        need_rotate=False,
        only_flip: bool = True,
        is_wide_line=False):
    num = len(paths)
    i = 0
    while True:
        # sizes = [512,512,512,512,640,1024]  # multi-scale training
        # size = np.random.choice(sizes,1)[0]
        size = 512
        X = np.zeros((batchsize, size, size, 3))
        Y = np.zeros((batchsize, size, size, 2))
        print(i)
        for j in range(batchsize):
            while True:
                if i >= num:
                    i = 0
                    np.random.shuffle(paths)
                p = paths[i]
                i += 1
                try:
                    if is_raw:
                        img, lines, labelImg = get_img_label_raw(p,
                                                                 line_path,
                                                                 size=(size, size),
                                                                 draw_line=draw_line,
                                                                 is_wide_line=is_wide_line)
                    else:
                        img, lines, labelImg = get_img_label_transform(p,
                                                                       line_path,
                                                                       size=(size, size),
                                                                       draw_line=draw_line,
                                                                       need_rotate=need_rotate,
                                                                       only_flip=only_flip,
                                                                       is_wide_line=is_wide_line)
                    break
                except Exception as e:
                    print(e)
            X[j] = img
            Y[j] = labelImg
        yield X, Y
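
The control flow of gen (loop forever, reshuffle when the list is exhausted, skip samples that fail to load) can be isolated from the image handling. Below is a minimal sketch in plain Python; batch_forever and its load parameter are hypothetical names, and load stands in for get_img_label_raw / get_img_label_transform.

```python
import random
from itertools import islice


def batch_forever(items, batchsize, load=lambda item: item, seed=None):
    """Yield lists of `batchsize` loaded items forever, reshuffling after each pass."""
    rng = random.Random(seed)
    items = list(items)  # copy so the caller's list is not shuffled in place
    i = 0
    while True:
        batch = []
        while len(batch) < batchsize:
            if i >= len(items):
                # End of a pass over the data: start over in a new order.
                i = 0
                rng.shuffle(items)
            item = items[i]
            i += 1
            try:
                batch.append(load(item))
            except Exception as exc:
                # Skip unreadable samples and move on to the next one.
                print(f"skipping {item!r}: {exc}")
        yield batch


if __name__ == "__main__":
    for batch in islice(batch_forever(range(5), 2, seed=0), 4):
        print(batch)
```

As in gen, the loop never terminates on its own, and it will spin forever if every sample fails to load, so the consumer has to decide how many batches make up an epoch.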

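Model training code

Training method: fit

Training from a data generator used to go through fit_generator; in TF2 this is unified under fit.

How to write the steps arguments: important!

Pay attention to how steps_per_epoch and validation_steps are written: batchsize must equal the batchsize passed to makeDataset, otherwise the step counts will be wrong.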

model.fit(training_ds,
          callbacks=[checkpointer, earlyStopping],
          steps_per_epoch=max(1, len(trainP) // batchsize),
          validation_data=validation_ds,
          validation_steps=max(1, len(testP) // batchsize),
          epochs=300)
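
Because the generator never stops, fit relies on these step counts to know where an epoch ends. The arithmetic used above can be sketched as a small helper; steps_for is a hypothetical name, not part of Keras.

```python
def steps_for(num_samples: int, batchsize: int) -> int:
    """Number of whole batches in one pass over the data, never less than one."""
    if batchsize <= 0:
        raise ValueError("batchsize must be positive")
    return max(1, num_samples // batchsize)


if __name__ == "__main__":
    # 1000 samples at a global batch size of 16 gives 62 full batches;
    # the remaining 8 samples roll over into the next epoch.
    print(steps_for(1000, 16))
    # Fewer samples than one batch still yields a single step.
    print(steps_for(5, 16))
```

Passing the per-GPU batch size here while the dataset uses the scaled one would multiply the number of steps per epoch by the GPU count.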
