论文阅读:GBDT能否被深度学习取代——TabNet

论文名称:《TabNet: Attentive Interpretable Tabular Learning》
论文地址:https://arxiv.org/abs/1908.07442
相关代码:https://github.com/dreamquark-ai/tabnet Pytorch版本(目前star:778)

《TabNet: Attentive Interpretable Tabular Learning》是google 2019年底的一篇论文,目前已更新到v5版本。其目标是使用深度学习注意力网络,构建具有可解释性的用于表格数据的模型;并且可以利用大量无标注数据,使用自监督学习的方法提高模型效果。

简介

深度学习网络在图像、文本、声音等领域都具有经典框架并取得很好的效果,但对于最常用的数据表(由类别和数值型数据组成)至今没有通用的深度学习经典框架。 在数据表领域更常见的方法是使用集成决策树,这主要是由于:它能在超平面上计算出数据切分边界,具有可解释性,且训练速度快;另一方面也源于之前的深度学习模型设计出的结构并不太适合表格数据,网络参数过多,没有很好地归纳。可能导致无法定位到最佳解决方案。

设想将深度学习网络引入表格数据处理,源于它的诸多优势,如:在数据量越多的情况下它的表现越优异,它可以构建end-to-end模型,将图片数据和表格数据结合起来;减少特征工程的难度(这也是处理表格数据最重要的问题),从数据流中学习,end-to-end使模型应用于更多的场景,如数据自适应(迁移学习),生成模型,半监督学习等等。

论文的主要贡献在于:
(1) 模型可直接使用表格数据,不需要预先处理;使用的基于梯度下降的优化方法,使它能方便地加入端到端(end-to-end)的模型中。
(2) 在每一个决策时间步,利用序列注意力模型选择重要的特征,学习到最突出的特征,使模型具有可解释性。这是一种基于实例的特征选择(对于每个实例选择的特征不同),且它的特征选择和推理使用同一框架。
(3) TabNet有两个明显优势,一方面是它在分类和回归中都表现出了与其它模型差不多的模型效果,另一方面,它具有局部可解释性(特征重要性和组合方法),和全局可解释性(特征对模型的贡献)。 (4) 对表格数据,使用无监督数据预训练(mask方法),可提高模型性能。

实现

关于具体实现的描述在论文第3-5页,看了一遍论文没看懂,又找了一篇翻译也没看懂,后来结合源码才看明白。下面按个人理解结合代码分析一下(与论文顺序不完全一致)。 以pytorch版本为例,核心网络结构的实现tab_network.py代码中。

输入数据

表格类型的输入数据一般是数值型或者类别型,数据值直接代入模型,而类别型可能涉及N种取值,为简化模型,TabNet使用了可训练的Embedding方式处理类别型数据,即把一个类别型特征转换成几维数值型特征,通过它们的组合来表征,具体实现见EmbeddingGenerator类,它将每个类别型数据映射到数值类型,具体的映射方法通过训练得到。

编码器:TabNetEncoder

从结构图可以看到,无论是用于填充缺失特征的无监督学习(左),还是用于实际决策的有监督学习(右),都使用编码器TabNet encoder先将输入特征编码;然后根据不同用途分别与decoder连接填充缺失特征,或与全连接层相连实现最终决策。

注:网络结构层层包含比较复杂,建议在回归示例的代码中使用print(clf.network)查看具体网络结构

编码器由多个FeatTransformer和多个AttentiveTransformer堆叠而成。编码器的输入是处理好的数值型特征,输出包括编码后的特征a[i]和供最终决策使用的数据d[i]。 图中的step表示决策步(时间步),它有点像决策树中的判断结点,每一步都接收输入的所有特征,并选择其中几个特征计算。调用者可设定使用多少个Step(一般是3-10,图中示例为2 step)。每个Step接收数据特征(图中黑线)作为输入,并使用上一步的输出(图中红线)对数据特征加权(决定哪些特征更加重要)。而每一步的输出通过累加的方式用于最终决策。

注:Transformer模型最早出现在NLP处理中,详见论文《Attention is you need》,它替代了传统的RNN模型,Attention,Encoder,Decoder是其核心技术。

文中使用的技术可算是Transformer模型在表格数据处理中的应用。而上图中的Transformer指的是大模型中的子模块。

特征处理模块Feature transformer

论文中展示了四层的Feature transformer,它包括两个在所有决策步之间共享的层,和两个(Shared across decision steps)与单个决策步相关的层(Decision step dependent)。每一层又由GLU,归一化层和全连接层组成。

GLU门控线性单元Gated linear units,它在全连接网络上加入了门控,其公式如下:

等式的前半部分为全连接层,后半部分使用门控方式决定哪些信息传向下一层。 Encoder图中Feature transformer后面带有split,split指Feature transformer的输出分为两部分,一部分供最终预测Output使用,写作d[i],另一部分继续向后传递a[i],供注意力模块Attentive transformer使用。

Feature Transformer采用了两种不同的模块:所有时间步共享Shared across decision steps(整个TabNetEncoder共用一份),只影响单步的Decision step dependent(内部创建)。从图中可以看到,先处理了共用部分,又进一步处理了单步相关模块,使用根号5用于保证模型稳定。

上图中的FC+BN+GLU相当于代码中的GLU_Layer,而虚线框对应GLU_Block。 最终,每一步的输出d[i]经过RELU()处理后累加,再经过一个全连接层FC作出决策Out。

注意力模块Attentive transformer

以投资数据为示例,输入的每条记录是用户的具体情况,在第一步(第一个红字框)选择了职业相关的特征,通过职业特征计算出的结果,传给下一步(第二个红字框)并对投资相关的特征进行加权,有点像决策树在每一层判断节点上选择特征,而在这里选择的是多个特征的组合作为决策依据。

注意力模块,即Encoder图中的Attentive部分,它的灵感来源于处理文字和图片时只关注输入中的部分数据。在数据表中可看作每一步只选择其中几维数据处理,即特征选择,使用此方法简化了模型参数和学习效率。

其具体实现的方法是使用全连接层,归一化层,和Sparsemax,最终输出mask(用M表示)用于对输入数据加权。M的计算方法如下:

其中h是可被训练的模型参数(FC和BN实现),a[i-1]是前一个层提供给Attentive transformer的输入,Sparsemax可视为Softmax的稀疏化版本。P[i-1]记录该特征在之前step中的使用程度(一般来说,如果之前用过,本次就不太可能再用),具体计算方法是:

其中i是决策步step,γ为松弛参数,如果它为1,则只要使用过,则会再被使用,如果γ较大,则特征可被多次使用。P[0]一般设置为全1。如果发现有些特征完全没用(通过无监督学习发现),可将其P[0]设置成0,以免浪费计算资源。

为了达到更好稀疏效果,还计算了稀疏正则化,并将每一步的值计入整体损失中:

其中ε作为参数传入模型,一般是个非常小的值。M越稀疏,Lsparse越小,Loss也越小;反之在大多数特征都冗余时,Loss较大;它的作用是在每个时间步上尽量关注更少的特征。

模型可解释性

从原理上看,mask可以描述单个实例单步的特征重要性,使模型具备局部可解释性,全局可解释性则需要通过系数组合每个单步的重要性,文中提出了系数的计算方法:

等式描述第b实例在第i个时间步对整体决策的贡献。

表格数据的自监督学习

数据缺失是表格数据需要面对的重要问题,利用遮蔽训练方法(一般称mask方法,此mask与上面attention中的mask不同),是深度学习在自然语言处理中的一种常用方法,它使用大量无标数据训练,故意遮蔽(masked)一些有效数据,然后通过训练模型弥补数据缺失,间接实现了数据插补。使模型在数据缺失的情况下也能很好地工作。特别是在标注数据较少的情况下,效果更加明显。

具体方法是使用编码器与解码器结合的方式,解码器TabNetDecoder用于将编码后的特征还原到原始数据表特征。调用者也可以设置多步step解码,其中每一步由一个Feature transforer和一个全连接层组成。

具体方法是,使用S∈{0, 1}掩码,使用(1-S)遮蔽部分数据后传给编码器再解码,最终将解码后的数据乘S得到被遮蔽的特征(预测值),然后通过被遮蔽部分的实际值与预测值的误差调整模型参数。具体方法如下:

这里使用了正则化,以避免各特征取值范围不同的问题。另外,在每次迭代时利用伯努利分布重新对S采样,以保证遮蔽各个特征。

时间步

使用决策树模型时,决策树的顶部一般是比较有辨识度的特征(对大多数实例重要的特征),预测时每个实例都使用同一决策树。 TabNet模型也在每个时间步选择部分特征运算,像是决策中的各个结点。不同的是,每一步可选择一个或多个特征组合,且可以根据不同实例在各个时间步选择不同的特征(该功能主要由Attention选择参数和step多步方法提供)。

从理论上说,相对于树模型对特征全局性评估,TabNet更善于评估单个实例的特征重要性。也因为如此,预测和归因可以使用了同一框架——它在决策过程中选择的特征,就是该实例的重要特征。

用法

论文的实验和附录部分在不同数据集上对比了树模型与TabNet模型的效果,总之,准确率不比树模型低,且模型也不大,下面记录了自己做的一些实验,谈谈主观感受。

安装

在网上可以找到基于pytorch和tensorflow的TabNet版本(下载地址见参考部分),以pytorch版本为例,为简安装方法如下:

$ pip install pytorch-tabnet 

注:它对pytorch,scipy,sklearn,matplotlib都有一定要求

示例

以pytorch版本为例,其网络结构实现在:

https://github.com/dreamquark-ai/tabnet/tree/develop/pytorch_tabnet.py
(调用逻辑:tab_model.py->abstract_model.py->tab_network.py)

如果不修改网络内部结构,仅使用它构建好的网络来训练模型和预测,非常简单,它封装成了类似于sklearn的调用方法,几乎不需要了解深度学习库的用法,调用API即可,回归例程,可参考:
https://github.com/dreamquark-ai/tabnet/blob/develop/regression_example.ipynb
例程中对比了TabNet与XGBoost的用法和效果。

使用效果

测试了代码中自带的回归示例,训练集大小(26072, 14),使用我笔记本上的cpu,一次迭代约一两秒,146次迭代后early_stop,相对于xgboost的百次迭代还是慢了很多。对测试集(3267, 14),tabnet预测速度0.15秒,而xgboost 0.012秒,虽说差了十倍,但还在可接受的范围内。又试了一下GPU的机器,也没有太明显的提速。个人感觉在小数据量(表格数据)处理中,是否使用GPU没那么重要。
在我的机器上,使用默认参数测试,xgboost在速度和准确率上还是有明显的优势。

参考

Attention注意力机制
https://www.jianshu.com/p/1012297aff38

GLU(Gated Linear Units)门控线性单元
https://blog.csdn.net/qq_32458499/article/details/81513720

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 151,511评论 1 330
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 64,495评论 1 273
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 101,595评论 0 225
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 42,558评论 0 190
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 50,715评论 3 270
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 39,672评论 1 192
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 31,112评论 2 291
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 29,837评论 0 181
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 33,417评论 0 228
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 29,928评论 2 232
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 31,316评论 1 242
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 27,773评论 2 234
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 32,253评论 3 220
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 25,827评论 0 8
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 26,440评论 0 180
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 34,523评论 2 249
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 34,583评论 2 249

推荐阅读更多精彩内容