这么骚!Batch Normalization 还能多卡同步?(附源码解析)

Date: 2020/01/30

Author: CW

前言:

在上一篇文(你真的了解 Batch Normalization 么?)的彩蛋中,CW提到了使用多GPU卡训练的情况下Batch Normalization(BN)可能会带来的问题,目前在很多深度学习框架如 Caffe、MXNet、TensorFlow 和 PyTorch 等,所实现的 BN 都是非同步的(unsynchronized),即归一化操作是基于每个 GPU上的数据独立进行的。

本文会为大家解析 BN 的多卡同步版本,这里简称 SyncBN,首先解释为何需要进行同步,接着为大家揭晓需要同步哪些信息,最后结合基于 Pytorch 实现的代码解析实现过程中的关键部分。


Outline

i). Why Synchronize BN:为何在多卡训练的情况下需要对BN进行同步?

ii). What is Synchronized BN:什么是同步的BN,具体同步哪些东西?

iii). How to implement:如何实现多卡同步的BN?

    (1). 2次同步 vs 1次同步;

    (2). 介绍 torch.nn.DataParallel 的前向反馈;

    (3). 重载 torch.nn.DataParallel.replicate 方法;

    (4). SyncBN 的同步注册机制;

    (5). SyncBN 的前向反馈


Why Synchronize BN:为何在多卡训练的情况下需要对BN进行同步?

对于视觉分类和目标检测等这类任务,batch size 通常较大,因此在训练时使用 BN 没太大必要进行多卡同步,同步反而会由于GPU之间的通信而导致训练速度减慢;

然而,对于语义分割等这类稠密估计问题而言,分辨率高通常会得到更好的效果,这就需要消耗更多的GPU内存,因此其 batch size 通常较小,那么每张卡计算得到的统计量可能与整体数据样本具有较大差异,这时候使用 BN 就有一定必要性进行多卡同步了。

多卡情况下的BN(非同步)

这里再提一点,如果使用pytorch的torch.nn.DataParallel,由于数据被可使用的GPU卡分割(通常是均分),因此每张卡上 BN 层的batch size(批次大小)实际为\frac{BatchSize}{nGPU},下文也以torch.nn.DataParallel为背景进行说明。


What is Synchronized BN:什么是同步的BN,具体同步哪些东西?

由开篇至今,CW 一直提到“同步”这两个字眼,那么到底是什么是同步的BN,具体同步的是什么东西呢?

同步是发生在各个GPU之间的,需要同步的东西必然是它们互不相同的东西,那到底是什么呢?或许你会说是它们拿到的数据,嗯,没错,但肯定不能把数据同步成一样的了,不然这就和单卡训练没差别了,浪费了多张卡的资源...

现在,聪明的你肯定已经知道了,需要同步的是每张卡上计算的统计量,即 BN 层用到的\mu(均值)和\sigma^2(方差),这样子每张卡对其拿到的数据进行归一化后的效果才能与单卡情况下对一个 batch 的数据归一化后的效果相当。

因此,同步的 BN,指的就是每张卡上对应的 BN 层,分别计算出相应的统计量\mu和\sigma^2,接着基于每张卡的计算结果计算出统一的 \mu和\sigma^2,然后相互进行同步,最后它们使用的都是同样的\mu和\sigma^2


How to implement:如何实现多卡同步的BN?

(1). 2次同步 vs 1次同步

我们已经知道,在前向反馈过程中各卡需要同步均值和方差,从而计算出全局的统计量,或许大家第一时间想到的方式是先同步各卡的均值,计算出全局的均值,然后同步给各卡,接着各卡同步计算方差...这种方式当然没错,但是需要进行2次同步,而同步是需要消耗资源并且影响模型训练速度的,那么,是否能够仅用1次同步呢?

全局的均值很容易通过同步计算得出,因此我们来看看方差的计算:

方差的计算,其中m为各GPU卡拿到的数据批次大小(\frac{BatchSize}{nGPU})。

由上可知,每张卡计算出\sum_{i=1}^mx_{i}和\sum_{i=1}^mx_{i}^2,然后进行同步求和,即可计算出全局的方差。同时,全局的均值可通过各卡的\sum_{i=1}^mx_{i}同步求和得到,这样,仅通过1次同步,便可完成全局均值及方差的计算。

1次同步完成全局统计量的计算

(2). 介绍nn.DataParallel的前向反馈

熟悉 pytorch 的朋友们应该知道,在进行GPU多卡训练的场景中,通常会使用nn.DataParallel来包装网络模型,它会将模型在每张卡上面都复制一份,从而实现并行训练。这里我自定义了一个类继承nn.DataParallel,用它来包装SyncBN,并且重载了nn.DataParallel的部分操作,因此需要先简单说明下nn.DataParallel的前向反馈涉及到的一些操作。

nn.DataParallel的使用,其中DEV_IDS是可用的各GPU卡的id,模型会被复制到这些id对应的各个GPU上,DEV是主卡,最终反向传播的梯度会被汇聚到主卡统一计算。

先来看看nn.DataParallel的前向反馈方法的源码:

nn.DataParallel.forward

其中,主要涉及调用了以下4个方法:

1. scatter:将输入数据及参数均分到每张卡上;

2. replicate:将模型在每张卡上复制一份(注意,卡上必须有scatter分割的数据存在!);

3. parallel_apply:每张卡并行计算结果,这里会调用被包装的具体模型的前向反馈操作(在我们这里就是会调用 SyncBN 的前向反馈方法);

4. gather:将每张卡的计算结果统一汇聚到主卡。

注意,我们的关键在于重载replicate方法,原生的该方法只是将模型在每张卡上复制一份,并且没有建立起联系,而我们的 SyncBN 是需要进行同步的,因此需要重载该方法,让各张卡上的SyncBN 通过某种数据结构和同步机制建立起联系

(3). 重载nn.DataParallel.replicate方法

在这里,可以设计一个继承nn.DataParallel的子类DataParallelWithCallBack,重载了replicate方法,子类的该方法先是调用父类的replicate方法,然后调用一个自定义的回调函数(这也是之所以命名为DataParallelWithCallBack的原因),该回调函数用于将各卡对应的 SyncBN 层关联起来,使得它们可以通过某种数据结构进行通信。

子类重载的replicate方法
自定义的回调函数,将各卡对应的Syn-BN层进行关联,其中DataParallelContext是一个自定义类,其中没有定义实质性的东西,作为一个上下文数据结构,实例化这个类的对象主要用于将各个卡上对应的Syn-BN层进行关联;_sync_replicas是在Syn-BN中定义的方法,在该方法中其余子卡上的Syn-BN层会向主卡进行注册,使得主卡能够通过某种数据结构和各卡进行通信。

(4). Syn-BN的同步注册机制

由上可知,我们需要在 SyncBN 中实现一个用于同步的注册方法,SyncBN 中还需要设置一个用于管理同步的对象(下图中的 _sync_master),这个对象有一个注册方法,可将子卡注册到其主卡。

在 SyncBN 的方法中,若是主卡,则将上下文管理器的 sync_master 属性设置为这个管理同步的对象(_sync_master);否则,则调用上下文对象的同步管理对象的注册方法,将该卡向其主卡进行注册。

Syn-BN的同步注册机制
主卡进行同步管理的类中注册子卡的方法
主卡进行同步管理的类
子卡进行同步操作的类

(5). Syn-BN的前向反馈

如果你认真看完了以上部分,相信这部分你也知道大致是怎样一个流程了。

首先,每张卡上的 SyncBN 各自计算出 mini-batch 的和\Sigma x以及平方和\Sigma x^2,然后主卡上的 SyncBN 收集来自各个子卡的计算结果,从而计算出全局的均值和方差,接着发放回各个子卡,最后各子卡的 SyncBN 收到来自主卡返回的计算结果各自进行归一化(和缩放平移)操作。当然,主卡上的 SyncBN 计算出全局统计量后就可以进行它的归一化(和缩放平移)操作了。

Syn-BN前向反馈(主卡)
Syn-BN前向反馈(子卡)

#最后

在同步过程中,还涉及线程和条件对象的使用,这里就不展开叙述了,感兴趣的朋友可以到SyncBN源码上瞄瞄。另外,在信息同步这部分,还可以设计其它方式进行优化,如果你有更好的意见,还请积极反馈,CW热烈欢迎!

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
禁止转载,如需转载请通过简信或评论联系作者。
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 159,015评论 4 362
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 67,262评论 1 292
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 108,727评论 0 243
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 43,986评论 0 205
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 52,363评论 3 287
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 40,610评论 1 219
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 31,871评论 2 312
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 30,582评论 0 198
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 34,297评论 1 242
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 30,551评论 2 246
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 32,053评论 1 260
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 28,385评论 2 253
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 33,035评论 3 236
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 26,079评论 0 8
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 26,841评论 0 195
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 35,648评论 2 274
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 35,550评论 2 270