推荐系统之AFM模型原理以及代码实践

简介

本文要介绍的是由浙江大学联合新加坡国立大学提出的AFM模型。通过名字也可以看出,此模型又是基于FM模型的改进,其中A代表”Attention“,即AFM模型实际上是在FM模型中引入了注意力机制改进得来的。之所以要在FM模型中引入注意力机制,是因为传统的FM模型对所有的交叉特征都平等对待,即每个交叉特征的权重都是相同的(都为1)。而在实际应用中,不同交叉特征的重要程度往往是不一样的。如果”一视同仁“地对待所有的交叉特征,不考虑不同特征对结果的影响程度,事实上消解了大量有价值的信息。

关于FM模型,可以参考推荐系统之FM(因子分解机)模型原理以及代码实践

推荐系统中的注意力机制

这里再举个例子,说明一下注意力机制是如何在推荐系统中派上用场的。注意力机制基于假设——不同的交叉特征对结果的影响程度不同,以更直观的业务场景为例,用户对不同交叉特征的关注程度应该是不同的。举例来说,如果应用场景是预测一位男性用户是否会购买一款键盘的可能性,那么”性别=男”“购买历史包含鼠标“这一交叉特征,很可能比”性别=男”“年龄=30“这一交叉特征重要,模型应该投入更多的”注意力“在前面的特征上。正因如此,将注意力机制引入推荐系统中也显得理所当然了。

模型

在介绍AFM模型之前,先给出FM模型的方程:

FM模型方程
其中w_0代表全局偏置,w_i代表第i个特征的权重,\hat w_{ij}代表交叉特征x_ix_j的权重,它可以被表示为:\hat w_{ij} = v_i^Tv_j,其中v_i \in \mathbb R^k代表第i个特征的embedding向量,k代表embedding向量的维度。由于存在系数x_ix_j,因此FM只能考虑两个均不为0的特征之间的交互。我们注意到,在FM中,所有的交叉特征\hat w_{ij}都具有相同的权重系数1,这可能会导致FM模型的泛化能力的下降,这也正是AFM需要改进的地方。

Pair-wise 交互层

Pair-wise 交互层将m个向量扩展到m(m-1)/2个交叉向量,每个交叉向量都是通过对两个不同的向量进行内积来计算的。可以通过以下公式来描述:

即对任意两个不同时为0的特征x_i,x_j,计算它们的乘积以及它们对应的embedding向量的乘积,最后再相乘起来。m个向量之间一共可以产生m(m-1)/2个不同的交叉特征向量。其实这样做是为了在神经网络架构下模拟FM中的交叉特征项。然后我们可以将f_{PI}( \mathcal{ E })通过一个求和池化层,再通过一个全连接层映射成最终的预测分数:
如果我们令p=1,b=0,那么我们就能完全还原FM模型。

这里其实跟NFM模型的核心操作类似,具体可以参考推荐系统之NFM模型原理以及代码实践

Attention-based Pooling层

下面看一下作者是如何将注意力机制加入到FM模型中去的。其实也很简单,只是在上述的交叉特征项前加入了注意力分数权重a_{ij},具体如下:


其中a_{ij}就是交叉特征\hat w_{ij}的注意力分数,它可以被看做是在预测过程中\hat w_{ij}的重要程度。
为了估计a_{ij},一个很自然的想法就是通过直接最小化目标损失函数来估计,在技术上似乎也是可行的。但是,对于从未在训练集中出现的共同出现过的特征,那么意味着x_i*x_j=0,故对应的注意力分数a_{ij}不可能通过估计得到。
为了解决这个问题,作者提出了通过MLP来参数化注意力分数,作者称之为”注意力网络“,其定义如下:
其中要学习的模型参数就是特征交叉层到注意力网络全连接层的权重W,偏置向量b,以及全连接层到Softmax输出层的权重向量h。注意力网络将于整个模型一起参与到梯度反向传播的过程,得到最终的权重参数。

直观来看,注意力网络就是将交叉特征首先通过一个全连接层,接着通过Relu激活函数,再乘以权重参数h得到a'_{ij},接着再通过一个Softmax层,将其映射成注意力权重,此时有\sum a_{ij} = 1

AFM模型

下面给出完整的AFM框架图:

AFM框架
其中绿色的方框代表的就是”注意力网络“,绿色箭头代表的是计算顺序。得到了交叉特征之后,先通过”注意力网络“获得注意力分数a_{ij},然后再将交叉特征乘以注意力分数a_{ij},再求和,即求和池化操作,最后得到预测分数。AFM模型的整体方程为:
AFM方程
模型的总参数为:
\Theta= \left\{w_{0}, \left\{w_{i}\right\}_{i=1}^{n},\left\{\mathbf{v}_{i}\right\}_{i=1}^{n}, \mathbf{p}, \mathbf{W}, \mathbf{b}, \mathbf{h}\right\}

代码实践

模型部分:

import torch
import torch.nn as nn
from BaseModel.basemodel import BaseModel

class AFM(BaseModel):
    def __init__(self, config, dense_features_cols, sparse_features_cols):
        super(AFM, self).__init__(config)
        self.num_fields = config['num_fields']
        self.embed_dim = config['embed_dim']
        self.l2_reg_w = config['l2_reg_w']

        # 稠密和稀疏特征的数量
        self.num_dense_feature = dense_features_cols.__len__()
        self.num_sparse_feature = sparse_features_cols.__len__()

        # AFM的线性部分,对应 ∑W_i*X_i, 这里包含了稠密和稀疏特征
        self.linear_model = nn.Linear(self.num_dense_feature + self.num_sparse_feature, 1)

        # AFM的Embedding层,只是针对稀疏特征,有待改进。
        self.embedding_layers = nn.ModuleList([
            nn.Embedding(num_embeddings=feat_dim, embedding_dim=config['embed_dim'])
                for feat_dim in sparse_features_cols
        ])

        # Attention Network
        self.attention = torch.nn.Linear(self.embed_dim, self.embed_dim, bias=True)
        self.projection = torch.nn.Linear(self.embed_dim, 1, bias=False)
        self.attention_dropout = nn.Dropout(config['dropout_rate'])

        # prediction layer
        self.predict_layer = torch.nn.Linear(self.embed_dim, 1)

    def forward(self, x):
        # 先区分出稀疏特征和稠密特征,这里是按照列来划分的,即所有的行都要进行筛选
        dense_input, sparse_inputs = x[:, :self.num_dense_feature], x[:, self.num_dense_feature:]
        sparse_inputs = sparse_inputs.long()

        # 求出线性部分
        linear_logit = self.linear_model(x)

        # 求出稀疏特征的embedding向量
        sparse_embeds = [self.embedding_layers[i](sparse_inputs[:, i]) for i in range(sparse_inputs.shape[1])]
        sparse_embeds = torch.cat(sparse_embeds, axis=-1)
        sparse_embeds = sparse_embeds.view(-1, self.num_sparse_feature, self.embed_dim)

        # calculate inner product
        row, col = list(), list()
        for i in range(self.num_fields - 1):
            for j in range(i + 1, self.num_fields):
                row.append(i), col.append(j)
        p, q = sparse_embeds[:, row], sparse_embeds[:, col]
        inner_product = p * q

        # 通过Attention network得到注意力分数
        attention_scores = torch.relu(self.attention(inner_product))
        attention_scores = torch.softmax(self.projection(attention_scores), dim=1)

        # dim=1 按行求和
        attention_output = torch.sum(attention_scores * inner_product, dim=1)
        attention_output = self.attention_dropout(attention_output)

        # Prodict Layer
        # for regression problem with MSELoss
        y_pred = self.predict_layer(attention_output) + linear_logit
        # for classifier problem with LogLoss
        # y_pred = torch.sigmoid(y_pred)
        return y_pred

在criteo数据集上测试,测试代码如下:

import torch
from AFM.network import AFM
from DeepCrossing.trainer import Trainer
import torch.utils.data as Data
from Utils.criteo_loader import getTestData, getTrainData

afm_config = \
{
    'num_fields': 26, # 这里配置的只是稀疏特征的个数
    'embed_dim': 8, # 用于控制稀疏特征经过Embedding层后的稠密特征大小
    'seed': 1024,
    'l2_reg_w': 0.001,
    'dropout_rate': 0.1,
    'num_epoch': 200,
    'batch_size': 64,
    'lr': 1e-3,
    'l2_regularization': 1e-4,
    'device_id': 0,
    'use_cuda': False,
    'train_file': '../Data/criteo/processed_data/train_set.csv',
    'fea_file': '../Data/criteo/processed_data/fea_col.npy',
    'validate_file': '../Data/criteo/processed_data/val_set.csv',
    'test_file': '../Data/criteo/processed_data/test_set.csv',
    'model_name': '../TrainedModels/AFM.model'
}

if __name__ == "__main__":
    ####################################################################################
    # AFM 模型
    ####################################################################################
    training_data, training_label, dense_features_col, sparse_features_col = getTrainData(afm_config['train_file'], afm_config['fea_file'])
    train_dataset = Data.TensorDataset(torch.tensor(training_data).float(), torch.tensor(training_label).float())

    test_data = getTestData(afm_config['test_file'])
    test_dataset = Data.TensorDataset(torch.tensor(test_data).float())

    afm = AFM(afm_config, dense_features_cols=dense_features_col, sparse_features_cols=sparse_features_col)
    ####################################################################################
    # 模型训练阶段
    ####################################################################################
    # # 实例化模型训练器
    trainer = Trainer(model=afm, config=afm_config)
    # 训练
    trainer.train(train_dataset)
    # 保存模型
    trainer.save()

    ####################################################################################
    # 模型测试阶段
    ####################################################################################
    afm.eval()
    if afm_config['use_cuda']:
        afm.loadModel(map_location=lambda storage, loc: storage.cuda(afm_config['device_id']))
        afm = afm.cuda()
    else:
        afm.loadModel(map_location=torch.device('cpu'))

    y_pred_probs = afm(torch.tensor(test_data).float())
    y_pred = torch.where(y_pred_probs>0.5, torch.ones_like(y_pred_probs), torch.zeros_like(y_pred_probs))
    print("Test Data CTR Predict...\n ", y_pred.view(-1))

点击率预估结果如下(预测用户会点击输出为1,反之为0):
测试结果

完整代码见:https://github.com/HeartbreakSurvivor/RsAlgorithms/tree/main/AFM

参考

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 158,736评论 4 362
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 67,167评论 1 291
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 108,442评论 0 243
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 43,902评论 0 204
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 52,302评论 3 287
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 40,573评论 1 216
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 31,847评论 2 312
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 30,562评论 0 197
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 34,260评论 1 241
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 30,531评论 2 245
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 32,021评论 1 258
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 28,367评论 2 253
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 33,016评论 3 235
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 26,068评论 0 8
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 26,827评论 0 194
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 35,610评论 2 274
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 35,514评论 2 269

推荐阅读更多精彩内容