生成对抗网络解读

摘要:生成对抗网络( Generative Adversarial Networks, GAN)是通过对抗训练的方式来使得生成网络产生的样本服从真实数据分布。在生成对抗网络中,有两个网络进行对抗训练。一个是判别器,目标是尽量准确地判断一个样本是来自于真实数据还是由生成器产生;另一个是生成器,目标是尽量生成判别网络无法区分来源的样本 。两者交替训练,当判别器无法判断一个样本是真实数据还是生成数据时,生成器即达到收敛状态。以上是对生成对抗网络的简单描述,本文将对生成对抗网络的内在原理以及相应的优化机制进行介绍。

文章概览

  • 概率生成模型

  • 生成对抗网络

    • 生成对抗网络的理论解释
    • 生成对抗网络的求解过程
  • 生成对抗网络的优化

    • fGAN
    • WGAN
  • 生成对抗网络的实现

    • GAN
    • CGAN
    • WGAN

概率生成模型

  概率生成模型,简称生成模型,是指一系列用于随机生成可观测数据的模型。假设在一个连续或离散的空间\chi中,存在一个随机向量X服从一个未知的数据分布P_{data}(x)x \in \chi。生成模型是根据一些可观测的样本x^1, x^2, ..., x^m来学习m一个参数化模型P_G(\theta;x)来近似未知分布,并可以用这P_{data}(x)个模型来生成一些样本,使得生成的样本和真实的样本尽可能的相似。对于一个低维空间中的简单分布而言,我们可以采用最大似然估计的方法来对p_\theta(x)进行求解。假设我们要统计全国人民的年均收入的分布情况,如果我们对每个P_{data}(x)样本都进行统计,这将消耗大量的人力物力。为了得到近似准确的收入分布情况,我们可以先假设其服从高斯分布,我们比如选取某个P_G(x;\theta)城市的人口收入x^1, x^2, ..., x^m,作为我们的观察样本结果,然后通过最大似然估计来计算上述假设中的高斯分布的参数。
L=\prod_{i=1}^{m} P_{G}\left(x^{i} ; \theta\right)

{\theta}^*=\arg \max _{\theta} \sum_{i=1}^{m} \log P_{G}\left(x^{i} ; \theta\right)

由于P_G(x;\theta)服从高斯分布,我们将其带入即可求得最终的近似的分布情况。下面我们对上述过程进行一些拓展,我们从P_{data}(x)尽可能采样更多的数据,此时可以得到
{\theta}^*=\arg \max _{\theta} \sum_{i=1}^{m} \log P_{G}\left(x^{i} ; \theta\right)\approx \arg \max _{\theta} E_{x \sim P_{\text {data }}}\left[\log P_{G}(x ; \theta)\right]
对该式进行一些变换,可以得到
{\theta}^*=\arg \min _{\theta} K L\left(P_{\text {data }} \| P_{G}\right)
  由此可以看出,最大似然估计的过程其实就是最小化P_{data}(x)分布和P_G分布之间KL散度的过程。从本质上讲,所有的生成模型的问题都可以转换成最小化P_{data}(x)分布和P_G分布之间距离的问题,KL散度只是其中一种度量方式。

  如上所述,对于低维空间的简单分布而言,我们可以显式的假设样本服从某种类型的分布,然后通过极大似然估计来进行求解。但是对于高维空间的复杂分布而言,我们无法假设样本的分布类型,因此无法采用极大似然估计来进行求解,生成对抗网络即属于这样一类生成模型。

生成对抗网络

生成对抗网络的理论解释

  在生成对抗网络中,我们假设低维空间中样本z服从标准类型分布,利用神经网络可以构造一个映射函数G(即生成器)将z映射到真实样本空间。我们希望映射函数G能够使得P_G(x)分布尽可能接近P_{data}(x)分布,即P_GP_{data}之间的距离越小越好:
G^{*}=\arg \min _{G} {\operatorname{Div}}\left(P_{G}, P_{\text {data }}\right)
由于P_GP_{data}的分布都是未知的,所以无法直接求解P_GP_{data}之间的距离。生成对抗网络借助判别器来解决这一问题。首先我们分别从P_GP_{data}中取样,利用取出的样本训练一个判别器:我们希望当输入样本为P_{data}时,判别器会给出一个较高的分数;当输入样本为P_G时,判别器会给出一个较低的分数。例如,我们可以将判别器的目标函数定义成以下形式(与二分类的目标函数一致,即交叉熵):
V(G, D)=E_{x \sim P_{\text {data }}}[\log D(x)]+E_{x \sim P_{G}}[\log (1-D(x))]
我们希望得到这样一个判别器(G固定):
D^{*}=\arg \max _{D} V(D, G)
从本质上来看,\max _{D} V(D, G)即表示P_GP_{data}之间的JS散度(具体推导参见李宏毅老师的课程),即:
\max _{D} V(G, D)=V\left(G, D^{*}\right)=-2 \log 2+2 J S D\left(P_{\text {data }} \| P_{G}\right)

D^{*}(x)=\frac{P_{\text {data }}(x)}{P_{\text {data }}(x)+P_{G}(x)}

因此通过构建判别器可以度量P_GP_{data}之间的距离,所以G^*可以表示为:
G^{*}=\arg \min _{G} \max _{D} V(G, D)

生成对抗网络的求解过程

G^*的求解过程大致如下:

  • 初始化生成器G和判别器D
  • 迭代训练
    • 固定生成器G,更新判别器D的参数
    • 固定生成器D,更新判别器G的参数
gan

对上述算法过程进行几点说明:

  • 在之前的描述中,V(D,G)表示的是目标函数的期望,但在实际计算过程中是通过采样平均的方式来逼近其期望值。
  • 判别器的训练需要重复k次的原因是希望能尽可能使得V(D,G)接近最大值,这样才能满足"\max _{D} V(D, G)即表示P_GP_{data}之间的JS散度"这一假设。
  • 在更新生成器参数时,\frac{1}{m} \sum_{i=1}^{m} \log D\left(x^{i}\right)这一项可以忽略,因为D固定,其相当于一个常数项。
  • 在更新生成器参数时,我们使用\tilde{V}=\frac{1}{m} \sum_{i=1}^{m} -\log \left(D\left(G\left(z^{i}\right)\right)\right)代替\tilde{V}=\frac{1}{m} \sum_{i=1}^{m} \log \left(1-D\left(G\left(z^{i}\right)\right)\right),这样做的目的是加速训练过程。

生成对抗网络的优化

fGAN

  通过上面的分析我们可以知道,构建生成模型需要解决的关键问题是最小化P_GP_{data}之间的距离,这就涉及到如何对P_GP_{data}之间的距离进行度量。在上述GAN的分析中,我们通过构建一个判别器来对P_GP_{data}之间的距离进行度量,其中采用的目标函数为:
V(G, D)=E_{x \sim P_{\text {data }}}[\log D(x)]+E_{x \sim P_{G}}[\log (1-D(x))]\
通过证明可知,V(G, D)其实度量的是P_GP_{data}之间的JS散度。如果我们希望采用其他方式来衡量两个分布之间的距离,则需要对判别器的目标函数进行修改。根据论文fGAN,可以将判别器的目标函数定义成如下形式:
D_{f^*}\left(P_{\text {data }} \| P_{G}\right)=\max _{\mathrm{D}}\left\{E_{x \sim P_{\text {data }}}[D(x)]-E_{x \sim P_{G}}\left[f^{*}(D(x))\right]\right\}
G^*可以表示为:
G^{*}=\arg \min _{G} D_{f^*}\left(P_{\text {data }} \| P_{G}\right)
f^*取不同表达式时,即表示不同的距离度量方式。

fgan

f^*(t)=-log(1-exp(t))D(x)log,代入D_{f^*}\left(P_{\text {data }} \| P_{G}\right)即可得到V(G,D)

WGAN

  自2014年Goodfellow提出以来,GAN就存在着训练困难、生成器和判别器的loss无法指示训练进程、生成样本缺乏多样性等问题。针对这些问题,Martin Arjovsky进行了严密的理论分析,并提出了解决方案,即WGAN(WGAN的详细解读可参考这篇博客)。

  • 判别器越好,生成器梯度消失越严重。根据上面的分析可知,当判别器训练到最优时,\max _{D} V(D, G)衡量的是P_GP_{data}之间的JS散度。问题就出在这个JS散度上,我们希望如果两个分布之间越接近它们的JS散度越小,通过优化JS散度就能将P_G拉向P_{adta}。这个希望在两个分布有所重叠的时候是成立的,但是如果两个分布完全没有重叠的部分,或者它们重叠的部分可忽略,J S D\left(P_{\text {data }} \| P_{G}\right)=log2。在训练过程中,P_GP_{data}都是通过采样得到的,在高维空间中两者之间几乎不存在交集,从而导致\max _{D} V(D, G)接近于0,生成器因此也无法得到有效训练。

  • 最小化生成器loss函数E_{x \sim P_{G}}[\log (1-D(x)),会等价于最小化一个不合理的距离衡量,导致两个问题,一是梯度不稳定,二是collapse mode即多样性不足。假设当前的判别器最优,经过推导可以得到下面等式:
    E_{x \sim P_{G}}[\log (1-D(x))=KL\left(P_{\text {G }} \| P_{data}\right)-2 J S D\left(P_{\text {data }} \| P_{G}\right)
    这个等价最小化目标存在两个严重的问题。第一是它同时要最小化生成分布与真实分布的KL散度,却又要最大化两者的JS散度,一个要拉近,一个却要推远!这在直观上非常荒谬,在数值上则会导致梯度不稳定,这是后面那个JS散度项的毛病。第二,即便是前面那个正常的KL散度项也有毛病,因为KL散度不是一个对称的衡量KL\left(P_{\text {G }} \| P_{data}\right)KL\left(P_{\text {data}} \| P_{G}\right)是有差别的。

  • 原始GAN的主要问题就出在距离度量方式上面,Martin Arjovsky提出利用Wasserstein距离来进行衡量。Wasserstein距离相比KL散度、JS散度的优越性在于,即便两个分布没有重叠,Wasserstein距离仍然能够反映它们的远近。
    W(G, D)=E_{x \sim P_{\text {data }}}[D(x)]-E_{x \sim P_{G}}[D(x)]\

wgan

  由以上算法可以看出,WGAN与原始的GAN在算法实现方面只有四处不同:(1)判别器最后一层去掉sigmoid;(2)生成器和判别器的loss不取log;(3)每次更新判别器的参数之后把它们的绝对值截断到不超过一个固定常数c;(4)不要用基于动量的优化算法(包括momentum和Adam),推荐RMSProp,SGD也行。

生成对抗网络的实现

  本文实现了几种常见的生成对抗网络模型,包括原始GAN、CGAN、WGAN、DCGAN。开发环境为jupyter lab,所使用的深度学习框架为pytorch,并结合tensorboard动态观测生成器的训练效果,具体代码请参考我的github。

GAN

real_label = torch.ones(batch_size, 1)
fake_label = torch.zeros(batch_size, 1)

# 训练判别器
d_real = D(real_img)
d_real_loss = criterion(d_real, real_label)

z = torch.normal(0, 1, (batch_size, latent))
fake_img = G(z)
d_fake = D(fake_img)
d_fake_loss = criterion(d_fake, fake_label)

optimizer_D.zero_grad()
d_loss = d_real_loss + d_fake_loss
d_loss.backward()
optimizer_D.step()

# 训练生成器
fake_img = G(z)
d_fake = D(fake_img)
g_loss = criterion(d_fake, real_label)

optimizer_G.zero_grad()
g_loss.backward()
optimizer_G.step()

CGAN

real_label = torch.ones(batch_size, 1)
fake_label = torch.zeros(batch_size, 1)

z = torch.normal(0, 1, (batch_size, latent))

# 训练判别器
d_real = D(real_img, label)
d_real_loss = criterion(d_real, real_label)

fake_img = G(z, label)
d_fake = D(fake_img, label)
d_fake_loss = criterion(d_fake, fake_label)

optimizer_D.zero_grad()
d_loss = (d_real_loss + d_fake_loss)
d_loss.backward()
optimizer_D.step()

# 训练生成器
fake_img = G(z, label)
d_fake = D(fake_img, label)
g_loss = criterion(d_fake, real_label)

optimizer_G.zero_grad()
g_loss.backward()
optimizer_G.step()

WGAN

# 训练判别器
d_real = D(real_img)
#d_real_loss = criterion(d_real, real_label)
d_real_loss = d_real

z = torch.normal(0, 1, (batch_size, latent))
fake_img = G(z)
d_fake = D(fake_img)
#d_fake_loss = criterion(d_fake, fake_label)
d_fake_loss = d_fake

optimizer_D.zero_grad()
#d_loss = d_real_loss + d_fake_loss
d_loss = torch.mean(d_fake_loss) - torch.mean(d_real_loss)
d_loss.backward()
optimizer_D.step()

for p in D.parameters():
    p.data.clamp_(-clip_value, clip_value)
# 训练生成器
fake_img = G(z)
d_fake = D(fake_img)
#g_loss = criterion(d_fake, real_label)
g_loss = - torch.mean(d_fake)

optimizer_G.zero_grad()
g_loss.backward()
optimizer_G.step()
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 158,560评论 4 361
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 67,104评论 1 291
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 108,297评论 0 243
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 43,869评论 0 204
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 52,275评论 3 287
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 40,563评论 1 216
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 31,833评论 2 312
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 30,543评论 0 197
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 34,245评论 1 241
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 30,512评论 2 244
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 32,011评论 1 258
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 28,359评论 2 253
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 33,006评论 3 235
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 26,062评论 0 8
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 26,825评论 0 194
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 35,590评论 2 273
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 35,501评论 2 268

推荐阅读更多精彩内容