扩散模型原理解析

去年写的文章,从notion的博客搬到这边来发一下(本来想搬到微信公众号的,但是那个格式真的反人类就作罢了),原文请到这里看mewimpetus.以后文章都会再这边先发。

引言

扩散模型是今年AI领域最热门的研究方向。由其引发的AI绘画的产业变革正在如火如荼的进行,大有淘汰一大票初中级画师的势头,目前主流的(诸如OpenAI的DALL-E 2;Google的ImageGen;以及已经商业化的MidJourney;注重二次元的NovelAI;开源引爆这波热潮的stable-diffusion)图像生成模型效果已经让人惊艳,若是再发展几年,它带来的影响将不可估量,可以说整个绘画产业正在经历着一场百年未有之大变局。而这些功能强大的绘画模型,无疑都与Denoising Diffusion Probabilistic Models 摆脱不了关系,它的原始论文由Google Brain在2020年发表。 这篇博文主要带大家一起来探究一下DDPM的工作原理和实现细节。

扩散模型的基本流程

其实扩散模型的基本思路同GAN以及VAE并无二致,都是试图从一个简单分布的随机噪声出发,经过一系列的转换,转变成类似于真实数据的数据样本。

它主要包含前向加噪声和反向去噪声两个过程:

  • 从真实的数据分布中随机采样一个图片,然后通过一个固定的过程逐步往上面添加高斯随机噪声,直到图片变成一个纯粹的噪声
  • 构建一个神经网络,去学习一个去噪的过程,从一个纯粹的噪声出发,逐步还原回一个真实的图像。

接下来我们用数学形式来表达上面的两个过程。

前向扩散

我们将真实数据的分布定义为q(x),然后可以从这个分布中随机采样一个”真图“ x_0 \sim q(x),于是我们就可以定义一个前向扩散的递推过程q(x_t|x_{t-1})为每个时间步t添加少量高斯噪声并执行T步。DDPM作者将q(x_t|x_{t-1})定义为这样一个条件高斯分布(其中的\{\beta_t \in (0, 1)\}_{t=1}^T是一个既定的递增表):

q(\mathbf{x}*t|\mathbf{x}*{t-1})=\mathcal{N}(\mathbf{x}*t;\sqrt{1-\beta_t}\mathbf{x}*{t-1},\beta_tI)

显然,当t-1时刻的图像为x_{t-1}的条件下,t时刻的图像X_t服从一个均值\mu_t=\sqrt{1-\beta_t}x_{t-1},方差\sigma^2=\beta_tI的各项同性高斯分布。我们再观察一下这个递推式,因为1-\beta\beta都小于1,显然x_t的均值会比x_t-1更加趋向于0,方差也更趋向于I,因此如果设计合适的\beta_t序列,最终的X_T将趋近于标准的高斯分布\mathcal{N}(0,I)。根据高斯分布的性质1:

如果X \sim \mathcal{N}(\mu, \sigma^2)ab都是实数,那么aX+b \sim \mathcal{N}(a\mu+b,(a\sigma)^2)

上述的条件高斯分布显然可以通过从标准高斯分布的线性变换得到,我们定义\epsilon_t\sim \mathcal{N}(0, I),那么只要让\mathbf{x}*t=\sqrt{\beta_t}\epsilon_t +\sqrt{1-\beta_t}\mathbf{x}*{t-1},那么第t个时间步的图像\mathbf{x}_t\sim \mathcal{N}(\mathbf{x}*t;\sqrt{1-\beta_t}\mathbf{x}*{t-1},\beta_tI)

为了更好的计算任意时刻t的条件分布,我们根据上面的递推式逐步推导到x_0,为了方便推导,我们令\alpha_t=1-\beta_t\overline\alpha_t=\prod^t_{i=1}\alpha_i则有了推导1

\begin{aligned}\mathbf{x}*t &= \sqrt{1-\alpha_t}\epsilon_t + \sqrt{\alpha_t}\mathbf{x}*{t-1} \\&=\sqrt{1-\alpha_t}\epsilon_t +\sqrt{\alpha_t}(\sqrt{1-\alpha_{t-1}}\epsilon_t+\sqrt{\alpha_{t-1}}\mathbf{x}*{t-2}) \\&= \sqrt{1-\alpha_t}\epsilon_t+\sqrt{\alpha_t (1-\alpha*{t-1})} \epsilon_t + \sqrt{\alpha_t\alpha_{t-1}}\mathbf{x}*{t-2} \\ &= \sqrt {1-\alpha_t\alpha*{t-1}}\epsilon_t + \sqrt{\alpha_t\alpha_{t-1}}\mathbf{x}_{t-2} \\&= ... \\&=\sqrt{1-\overline\alpha_t}\epsilon_t +\sqrt{\overline\alpha_t}\mathbf{x}_0\end{aligned}

上式中第3行到第4行的推导用到了上述的性质1,以及高斯分布的另一个性质2

如果X\sim\mathcal{N}(\mu_X,\sigma_X^2)Y \sim \mathcal{N}(\mu_Y,\sigma_Y^2) 是独立统计的高斯随机变量,那么,它们的和也满足高斯分布X + Y \sim \mathcal{N}(\mu_X +\mu_Y, \sigma^2_X+\sigma^2_Y)

性质1可知,\sqrt{1-\alpha_t}\epsilon_t \sim\mathcal{N}(0,(1-\alpha_t)I),而\sqrt{\alpha_t(1-\alpha_{t-1})}\epsilon_t \sim \mathcal{N}(0, \alpha_t(1-\alpha_{t-1})I) ,再根据性质2,就可得\sqrt{1-\alpha_t}\epsilon +\sqrt{\alpha_t(1-\alpha_{t-1})}\epsilon \sim \mathcal{N}(0,(1-\alpha_t\alpha_{t-1})I), 再根据性质1写回到多项式的形式即得到推导的结果。

基于这个最终的推导结果,因为\alpha_t是事先已经定义好的,我们只需要给出初始真实分布采样x_0,即可以计算出任何第t步的样本x_t ,而不需要每次都从x_1开始一步步计算。

反向去噪

有了前向的过程,我们反过来想,既然前向扩散是一个马尔可夫过程,那么它的逆过程显然也是马尔可夫过程,如果我们可以构造一个相反的条件分布q(x_{t-1}|x_t),那不就可以从最终的x_T开始一步步地去噪,从而反推回初始的x_0了吗? 但是我们并不知道反向条件高斯分布的均值和方差。不过,在这个深度学习的时代,我们可以从真实数据集X_0出发,通过前向过程生成一系列的x_0 \to x_T 的真实扩散序列,然后设计一个神经网络从这些序列中来近似学习一个分布p_\theta(x_{t-1}|x_t)使其接近真实的q(x_{t-1}|x_t),其中的\theta是这个神经网络需要学习的参数,于是从x_T变换到x_0的概率可以表示成:

p_\theta(\mathbf{x}*{T:0}) = p(\mathbf{x}*T)\prod*{t=1}^Tp*\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)

当我们前向过程所定义的\beta_t足够小时,反向过程也满足高斯分布,因此我们可以假设神经网络要学习的这个分布是高斯分布,这意味着它需要去学习其均值\mu_\theta 和方差\Sigma_\theta ,换成与上述前向过程相同的表示则有递推公式:

p_\theta(\mathbf{x}*{t-1}|\mathbf{x}*t) = \mathcal{N}(\mathbf{x}*{t-1};\mu*\theta(\mathbf{x}*t, t),\Sigma*\theta(\mathbf{x}_t, t))

借助这个公式,我们就可以完成去噪过程了,接下来的任务变成了如何训练这个神经网络。

如何训练

基本思路

不知大家又没有觉得这个加噪声和去噪声的过程和VAE的编码和解码的过程十分类似,那么是否可以从VAE的训练方式中得到一些启发呢?实际上作者就是这么想的。

图像前向扩散及反向去噪的过程,正向概率已知,反向的转移概率是我们希望得到的。

显然,如果直接使用\mathbf{x}*t{\mathbf{x}'}*{t}的对比误差会导致模型过拟合成AE一样的无生成能力的模型。因此,我们使用与VAE类似的变分推断的方法,希望网络输出的{\mathbf{x}'}*{t-1}尽量接近由真实\mathbf{x}*0变化而来的\mathbf{x}*{t-1}的分布,即最小化似然p*\theta(\mathbf{x}*{t-1}|x*{t})与真实的q(\mathbf{x}*{t-1}|\mathbf{x}*{t}, \mathbf{x}*0)D*{KL}(q(\mathbf{x}*{t-1}|\mathbf{x}*{t}, \mathbf{x}*0))||p*\theta(\mathbf{x}*{t-1}|\mathbf{x}*{t})。于是每一个时间步骤\{t \in [1,T]\}的误差可以定义为:

L_t=\left\{\begin{aligned}&-log~p_\theta(\mathbf{x}*0|\mathbf{x}*1) ~~~,when~ t=1\\&D*{KL}(q(\mathbf{x}*{t-1}|\mathbf{x}*{t}, \mathbf{x}*0))||p*\theta(\mathbf{x}*{t-1}|\mathbf{x}_{t}) ~ ~~,when~ t \in [2,T]\end{aligned}\right.

而当t=1时,因为q(\mathbf{x}*0)是确定的,因此可以忽略这部分,故而L_1 = -log~p*\theta(\mathbf{x}_0|\mathbf{x}_1),因此

于是整个去噪过程的误差就是: L=\sum^{T}_{t=1}L_t 。实际训练时,我们并没有使用整体的误差L,而是通过均匀随机选择t ,来最小化 L_t

目标函数

要直接计算上面的KL散度是困难的,但是正如前面所说的,q(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0) 是一个高斯分布,于是根据贝叶斯公式有:

\begin{aligned} q(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0) &= q(\mathbf{x}*t \vert \mathbf{x}*{t-1}, \mathbf{x}*0) \frac{ q(\mathbf{x}*{t-1} \vert \mathbf{x}_0) }{ q(\mathbf{x}*t \vert \mathbf{x}*0) } \\ &\propto \exp \Big(-\frac{1}{2} \big(\frac{(\mathbf{x}*t - \sqrt{\alpha_t} \mathbf{x}*{t-1})^2}{\beta_t} + \frac{(\mathbf{x}*{t-1} - \sqrt{\bar{\alpha}*{t-1}} \mathbf{x}*0)^2}{1-\bar{\alpha}*{t-1}} - \frac{(\mathbf{x}_t - \sqrt{\bar{\alpha}*t} \mathbf{x}*0)^2}{1-\bar{\alpha}*t} \big) \Big) \\ &= \exp \Big(-\frac{1}{2} \big(\frac{\mathbf{x}*t^2 - 2\sqrt{\alpha_t} \mathbf{x}*t {\color{blue}{\mathbf{x}*{t-1}}} + \alpha_t {\color{red}{\mathbf{x}*{t-1}^2}} }{\beta_t} + \frac{ {\color{red}{\mathbf{x}*{t-1}^2}} {- 2 \sqrt{\bar{\alpha}*{t-1}} \mathbf{x}*0} \color{blue}{\mathbf{x}*{t-1}}{+ \bar{\alpha}*{t-1} \mathbf{x}*0^2} }{1-\bar{\alpha}*{t-1}} - \frac{(\mathbf{x}_t - \sqrt{\bar{\alpha}*t} \mathbf{x}*0)^2}{1-\bar{\alpha}*t} \big) \Big) \\ &= \exp\Big( -\frac{1}{2} \big( {\color{red}{(\frac{\alpha_t}{\beta_t} + \frac{1}{1 - \bar{\alpha}*{t-1}})} \mathbf{x}*{t-1}^2} - \color{blue}{(\frac{2\sqrt{\alpha_t}}{\beta_t} \mathbf{x}*t + \frac{2\sqrt{\bar{\alpha}*{t-1}}}{1 - \bar{\alpha}*{t-1}} \mathbf{x}*0)} \mathbf{x}*{t-1}{ + {\color{green}{C(\mathbf{x}_t, \mathbf{x}_0)} }\big) \Big)} \end{aligned}

其中\color{green}{C(\mathbf{x}_t,\mathbf{x}*0)} 代表所有剩余与\mathbf{x}*{t-1}无关的项。

根据高斯分布的基本方程:

与上述的推导结果位置依次对应可得其方差和均值为:

根据上面的推导1可得\mathbf{x}_0 = \frac{1}{\sqrt{\overline{\alpha}_t}}(\mathbf{x}_t - \sqrt{1 - \bar{\alpha}_t}\boldsymbol{\epsilon_t}) ,带入上式可得:

\boldsymbol\mu(\mathbf{x}_t,\mathbf{x}_0)= \frac{1}{\sqrt{\alpha_t}}(\mathbf{x}_t - \frac{\beta_t}{\sqrt{1-\bar\alpha}}\epsilon_t)

最小化上述的KL散度,可以转化为计算神经网络的预测的均值方差与上述均值方差的L2损失:

\begin{aligned} L_t^{\Sigma} &= ||\Sigma(\mathbf{x}_t,\mathbf{x}*0)-\Sigma*\theta(\mathbf{x}_t,t)||^2 \\ L_t^{\mu} &=||\mu(\mathbf{x}_t,\mathbf{x}*0)-\mu*\theta(x_t,t))|| \\ &= ||\frac{1}{\sqrt{\alpha_t}}(\mathbf{x}_t - \frac{\beta_t}{\sqrt{1-\bar\alpha}}{\color{red}\epsilon_t}) - \frac{1}{\sqrt{\alpha_t}}(\mathbf{x}*t - \frac{\beta_t}{\sqrt{1-\bar\alpha}}{\color{green}\epsilon*\theta(\mathbf{x}_t, t)})||^2 \end{aligned}

DDPM的论文作者在论文中说他使用一个固定的方差取得了差不多的效果,因此他的神经网络只去学习了均值,而把方差设置成了\beta_t 或者是\frac{1-\bar a_{t-1}}{1-\bar a_t}\beta_t,因此我们接下来的推导也只考虑均值。后来Improved diffusion models 这篇论文将其改进后就让神经网络同时去学习均值和方差了,有兴趣的同学可以自行去了解。

观察上面的L_t^{\mu},除了\epsilon_\theta,其余项均为固定值与\theta无关,于是我们不妨将神经网络的学习目标从高斯分布的均值转变为\epsilon_\theta ,即去预测每个事件步的噪声量而非高斯分布的均值,因此我们最终的目标函数就变成了:

\begin{aligned} L_t &= ||\epsilon - \epsilon_\theta(\mathbf{x}*t,t)||^2 \\ &=||\epsilon - \epsilon*\theta(\sqrt{1-\overline\alpha_t}\epsilon +\sqrt{\overline\alpha_t}\mathbf{x}_0,t)||^2 \end{aligned}

然后,整个训练算法便是这样一个过程:

  1. 从真实的复杂未知分布q(x)随机抽取一个样本x_0
  2. 1T均匀采样一个时间步t
  3. 从均值为**0**方差为I的标准高斯分布中随机采样一个\epsilon
  4. 计算随机梯度\nabla_\theta ||\epsilon - \epsilon_\theta(\sqrt{1-\overline\alpha_t}\epsilon +\sqrt{\overline\alpha_t}\mathbf{x}_0,t)||^2 ,并通过随机梯度下降优化\theta
  5. 重复上述过程直到收敛

采样生成

当上述的神经网络学习好\epsilon_\theta , 就可以计算出均值\mu_\theta = \frac{1}{\sqrt{\alpha_t}}(\mathbf{x}*t - \frac{\beta_t}{\sqrt{1-\bar\alpha}}\epsilon*\theta) ,于是我们就可以从一个随机高斯噪声\mathbf{x}*T \sim \mathcal{N}(0,I) ,通过条件去噪概率p*\theta(\mathbf{x}*{t-1}|\mathbf{x}*{t}) = \mathcal{N}(\mu_\theta(\mathbf{x}_t,t),\sigma_t^2) 进行采样生成,逐步从\mathbf{x}_T\mathbf{x}_0

具体来说Sampling是这样一个过程:

  1. 随机采样一个\mathbf{x}_T \sim \mathcal{N}(0,I)

  2. t=T,…,1 ,依次执行:

    z \sim \mathcal{N}(0,I) \\ \\ \mathbf{x}*{t-1} = \mu*\theta(\mathbf{x}_t,t) + \sigma_t z

  3. 返回最终的\mathbf{x}_0

网络结构

虽然有了训练的方案,但是如何来设计这个神经网络\epsilon_\theta才能让我们这个扩散和反扩散的过程取得较好的效果呢?DDPM的作者选择了U-Net ,并且在实验中取得了很好的效果。

这个用于学习\epsilon_\theta的U-Net网络十分复杂,由一系列的诸如下采样、上采样、残差、位置Embedding、ResNet/ConvNeXT block、注意力模块、Group Normalization等组件组合而成,为了让大家了解整个网络各个组件的具体结构和连接方式,我绘制了一个详细的网络图:

DDPM 反向去噪神经U-Net网络结构

根据这个图,我们可以用tensorflow或者pytorch非常轻松的实现这个网络。不过显然这个网络很大,特别是图片很大时占用的显存会很高,而且采样步骤多推理也很慢,因此后面有很多对于DDPM的改进,篇幅关系,关于对DDPM的改进我们下篇文章再讲。

参考资料

  1. What are Diffusion Models?
  2. Understanding Diffusion Models: A Unified Perspective
  3. Denoising Diffusion Probabilistic Models
  4. Understanding Variational Autoencoders (VAEs)
  5. 正态分布- 维基百科,自由的百科全书 - Wikipedia
  6. U-Net: Convolutional Networks for Biomedical Image Segmentation
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 158,560评论 4 361
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 67,104评论 1 291
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 108,297评论 0 243
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 43,869评论 0 204
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 52,275评论 3 287
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 40,563评论 1 216
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 31,833评论 2 312
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 30,543评论 0 197
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 34,245评论 1 241
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 30,512评论 2 244
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 32,011评论 1 258
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 28,359评论 2 253
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 33,006评论 3 235
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 26,062评论 0 8
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 26,825评论 0 194
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 35,590评论 2 273
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 35,501评论 2 268

推荐阅读更多精彩内容