图注意力网络(GAT,GraphAttentionNetwork)

GAT(GRAPH ATTENTION NETWORKS)是一种使用了self attention机制图神经网络,该网络使用类似transformer里面self attention的方式计算图里面某个节点相对于每个邻接节点的注意力,将节点本身的特征和注意力特征concate起来作为该节点的特征,在此基础上进行节点的分类等任务。

图1:节点的特征由节点本身和直接相连的节点共同决定

下面是transformer self attention原理图:


图2:Transformer self attention

GAT使用了类似的流程计算节点的self attention,首先计算当前节点和每个邻接节点的注意力score,然后使用该score乘以每个节点的特征,累加起来并经过一个非线性映射,作为当前节点的特征。


图3:节点的特征计算

Attention score公式表示如下:
图4

图5

这里使用W矩阵将原始的特征映射到一个新的空间,a代表self attention的计算,如前面图2所示,这样计算出两个邻接节点的attention score,也就是Eij,然后对所有邻接节点的score进行softmax处理,得到归一化的attention score。
代码可以参考这个实现:https://github.com/gordicaleksa/pytorch-GAT
核心代码:

    def forward(self, data):
        in_nodes_features, connectivity_mask = data  
        num_of_nodes = in_nodes_features.shape[0]
        in_nodes_features = self.dropout(in_nodes_features)
        # V
        nodes_features_proj = self.linear_proj(in_nodes_features).view(-1, self.num_of_heads, self.num_out_features)

        nodes_features_proj = self.dropout(nodes_features_proj)  
        # Q、K
        scores_source = torch.sum((nodes_features_proj * self.scoring_fn_source), dim=-1, keepdim=True)
        scores_target = torch.sum((nodes_features_proj * self.scoring_fn_target), dim=-1, keepdim=True)

        scores_source = scores_source.transpose(0, 1)
        scores_target = scores_target.permute(1, 2, 0)
        # Q * K
        all_scores = self.leakyReLU(scores_source + scores_target)
        all_attention_coefficients = self.softmax(all_scores + connectivity_mask)
        # Q * K * V
        out_nodes_features = torch.bmm(all_attention_coefficients, nodes_features_proj.transpose(0, 1))

        out_nodes_features = out_nodes_features.permute(1, 0, 2)
        # in_nodes_features + out_nodes_features(attention)
        out_nodes_features = self.skip_concat_bias(all_attention_coefficients, in_nodes_features, out_nodes_features)
        return (out_nodes_features, connectivity_mask)

该GAT的实现也包含在了PYG库中,这个库涵盖了各种常见的图神经网络方面的论文算法实现。

推荐阅读更多精彩内容