INN实现深入理解

FrEIA:Framework for Easily Invertible Architectures

FrEIA 是实现 INN 的基础,可以理解为实现 INN 的最重要的工具包。其分为两个模块,每个模块又有很多重要的类,以下是对这些类及其方法的描述。

一、FrEIA.framework

framework 模块包含用于构建网络模型和推断节点在向前和向后方向上执行的顺序的逻辑。该模块包括四个类。

<1> Node类
一个 Node 即为一个 INN 基础构建块。

build_modules(verbose=True)
该方法返回该结点的输出结点的维度列表。会递归调用其输入结点的 build_modules 方法。使用这些信息来初始化该结点的 pytorch.nn.Module,即构建起基础块。

run_forward(op_list)
调用该方法来确定执行到该结点前的操作顺序。会递归调用其输入结点的 run_forward 方法。每个操作被追加到全局列表 op_list 中,操作的表示形式为 (node ID, input variable IDs, output variable IDs)。

run_backward(op_list)
该方法和 run_forward 相似,只是用来确定逆向执行到该结点前的操作顺序。调用该方法前须先调用 run_forward。

<2> InputNode类
InputNode 类是 Node 类的子类,是一种特殊的结点——输入结点,表示整个 INN 网络的输入数据(及逆向过程的输出数据)。

<3>OutputNode类
OutputNode 类也是 Node 类的子类,是一种特殊的结点——输出结点,表示整个 INN 网络的输出数据(及逆向过程的输入数据)。

<4>ReversibleGraphNet类
是 torch.nn.modules.module.Module 类的子类。这个类表示了 INN 本身。这个类的构造函数会确定 node list 中的输入结点和输出结点,并调用输出结点的 build_modules 方法以获得结点间的连接关系及维度,调用输出结点的 run_forward 方法获得正向过程的所有操作及涉及的变量;再调用输入结点的 run_backward 方法获得逆向过程的所有操作及涉及的变量。

二、FrEIA.modules

modules 模块中都是 torch.nn.Module 的子类,都是可逆的,可用于 ReversibleGraphNet 中 Node 的构建。

1. Coefficient functions

class FrEIA.modules.coeff_functs.F_conv(in_channels, channels, channels_hidden=None, stride=None, kernel_size=3, leaky_slope=0.1, batch_norm=False)
F_conv 类使用多层卷积网络实现 INN 基础构建块中的 s、t 转换。其本身是不可逆的。

class FrEIA.modules.coeff_functs.F_fully_connected(size_in, size, internal_size=None, dropout=0.0)
F_fully_connected 类使用包含四层全连接层的神经网络实现 INN 基础构建块中的 s、t 转换。其本身是不可逆的。

2. Coupling layers

F_* 类本身都是不可逆的,而 coupling_layers 层中的类就是赋予 INN 可逆能力的。

class FrEIA.modules.coupling_layers.rev_layer(dims_in, F_class=<class 'FrEIA.modules.coeff_functs.F_conv'>, F_args={})
当转换是 F_conv 时使用 rev_layer 类实现可逆。

class FrEIA.modules.coupling_layers.rev_multiplicative_layer(dims_in, F_class=<class 'FrEIA.modules.coeff_functs.F_fully_connected'>, F_args={}, clamp=5.0)
当转换是 F_fully_connected 时使用 rev_multiplicative_layer 实现可逆,思想是基于 RealNVP 基础构建块的双向转换。

class FrEIA.modules.coupling_layers.glow_coupling_layer(dims_in, F_class=<class 'FrEIA.modules.coeff_functs.F_fully_connected'>, F_args={}, clamp=5.0)
和 rev_multiplicative_layer 相似,只是思想是基于 Glow 基础构建块的双向转换。

3. Fixed transforms

这个模块的类也都是 torch.nn.Module 的子类,用于实现一些固定转换。在构建块之间插入置换层,用于随机打乱下一层的输入元素,这使得 u = [u1,u2] 的分割在每一层都不同,也因此促进了独立变量间的交互。

class FrEIA.modules.fixed_transforms.linear_transform(dims_in, M, b)
根据 y=Mx+b 做固定转换,其中 M 是可逆矩阵。其实现为:

class FrEIA.modules.fixed_transforms.permute_layer(dims_in, seed)
随机转换输入向量列顺序。其实现为:

4. Graph topology

用于对 INN 网络结构进行调整,如构建残差 INN 等。

class FrEIA.modules.graph_topology.cat_layer(dims_in, dim)
沿给定维度合并多个张量。

class FrEIA.modules.graph_topology.channel_merge_layer(dims_in)
沿着通道从两个独立的输入合并到一个输出(用于跳过连接等)。

class FrEIA.modules.graph_topology.channel_split_layer(dims_in)
沿通道拆分以产生两个单独的输出(用于跳过连接等)。

class FrEIA.modules.graph_topology.split_layer(dims_in, split_size_or_sections, dim)
沿给定维度拆分以生成具有给定大小的单独输出的列表。

5.Reshapes

class FrEIA.modules.reshapes.flattening_layer(dims_in)
正向过程是将 N 维张量拉平为 1 维的张量。

class FrEIA.modules.reshapes.haar_multiplex_layer(dims_in, order_by_wavelet=False)
使用 Haar wavelets 将通道按宽度和高度对半分为 4 个通道。

class FrEIA.modules.reshapes.haar_restore_layer(dims_in)
使用 Haar wavelets 将 4 个通道合并为一个通道。

class FrEIA.modules.reshapes.i_revnet_downsampling(dims_in)
i-RevNet 中使用的可逆空间 downsampling。

class FrEIA.modules.reshapes.i_revnet_upsampling(dims_in)
与 i_revnet_downsampling 相反的操作。

loss function 深入解析

INN 的 loss 分为两部分:前向训练的损失和后向训练的损失。前后向训练的简单过程如下:

对应的损失函数如下:

其中,L2 是一种求平方误差的损失函数;MMD 则用于求两个分布之间的差异。

  • 在 lossf 中,前半部分是用于求正向网络预测值 ygen 与通过已知正向模型计算而得的 ytrue 之间的平方误差,降低这部分损失是为了使 INN 正向模型更能拟合真实模型;后半部分损失则是计算真实隐藏信息的分布 ztrue 和服从标准正态分布的 z 之间的分布差异,降低这部分损失是为了将在正向过程中丢失的信息 z' 映射到服从标准正态分布的 latent-z 空间,这样我们通过对标准正态分布采样,加上观测值 y,即可得到一个 x。另外,在反向传播时,忽略 MMD 损失上的 y 的梯度,以便训练神经元学习从真实潜在分布到正态分布的映射而不妨碍正向模型 x→y 的训练,因此这种 MMD 损失的收敛确保了 z 与 y 的独立性。
  • 在 lossb 中,前半部分用于求逆向网络根据真实潜在信息 ztrue 和 ytrue 得到的预测值 xgen1 与 xtrue 的平方误差,降低这部分损失是为了让 INN 逆向模型更准确;后半部分损失是计算逆向网络根据标准正态分布 z 和 ytrue 得到的预测值 xgen2 与 xtrue 的分布差异,降低这部分 MMD 损失是为了确保根据标准正态分布中的采样点和 ytrue 得到的 x 的分布与其 xtrue 的真实分布看起来相似。而之所以给 MMD 损失增加了随训练轮数 n 逐渐增加的权重 ξ(n),是为了避免最初训练时 MMD 的大梯度使网络远离正确的解决方案。

重读 Analyzing Inverse Problems with INN

这篇文章前后读了很多次,这次重读,又有许多新收获。

INN 的基础构建块是 RealNVP 中的仿射耦合层;如果要构建深度可逆网络,可将基础构建块和残差训练思想结合,i-RevNet 就是深度可逆网络。基础构建块本身是可逆结构,因此其中的 s、t 转换不需要是可逆的,其可以是任意复杂的网络结构;更重要的是,构建块的可逆性允许我们同时为正向训练和反向训练应用损失函数,并可以从任一方向计算 s 和 t 的梯度。

至于如何将输入分割为两部分,目前采用的方法是随机(但为了可逆,必须是固定的)分割。但为了实现更好的训练结果,往往在基础构建块之间加入置换层,用于 shuffle 输入,这样是为了使得每次分割得到的 u1、u2 和上一构建块分割得到的不同。如果没有置换层,那么在训练过程中,u1、u2 都是相互独立的。

正向过程的信息丢失使得逆向过程存在模糊性,如果使用贝叶斯方法,用 p(x|y) 表示这种模糊性,也是可行的,但非常复杂。实际上我们希望网络学习完整的后验分布 p(x|y),我们可以尝试预测简单分布的拟合参数(如均值和方差),或者使用分布(而非固定值)来表示网络权重,甚至两者结合,但无论如何,这将限制我们预先选择一种固定的分布或分布族。 也可以使用 cGAN,但它难以训练的,并且经常遭遇难以检测的模式崩溃。

因此,INN 中引入 latent variable z,其捕获了从 x 到 y 的正向过程中丢失的信息,即 x 与 [y,z] 形成了双射。在正向训练过程中,我们可以得到真实潜在分布 ztrue,从 ztrue 中采样,再结合 y 值,即可得到一个固定的 x 值。即,p(x|y) 转换成一个固定函数 x=g(y,z)。不过,为了易于采样,我们需要将 ztrue 调整为一个简单的分布(如标准正态分布),这样我们就可以从简单分布中采样,结合 y 值得到 x,这样的准确率依然是很高的,因为简单分布和 ztrue 之间只是一个简单映射关系。至此看来,INN 和 cGAN 思想是很接近的,不过 INN 的可逆性使得其训练过程与 cGAN 不同,也有其特殊的优点。

INN 同时对正向和逆向过程进行训练,累积网络两端的损失项的梯度。 我们还可以在 x 上添加类似 GAN 的鉴别器损失,但到目前为止,在我们的应用程序中,MMD 是足够的,因此我们可以避免对抗性训练的麻烦。在大量训练数据的支持下,我们可以训练网络拟合正向模型,由于网络的可逆构造,我们可以免费得到逆向模型。

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 151,688评论 1 330
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 64,559评论 1 273
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 101,749评论 0 226
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 42,581评论 0 191
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 50,741评论 3 271
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 39,684评论 1 192
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 31,122评论 2 292
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 29,847评论 0 182
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 33,441评论 0 228
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 29,939评论 2 232
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 31,333评论 1 242
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 27,783评论 2 236
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 32,275评论 3 220
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 25,830评论 0 8
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 26,444评论 0 180
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 34,553评论 2 249
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 34,618评论 2 249

推荐阅读更多精彩内容