论文 | NeurIPS2020 CrossTransformers:spatially-aware few-shot transfer

一 写在前面

未经允许,不得转载,谢谢~~~

嘿,好久不见,我要开始慢慢恢复科研论文笔记的更新啦~

今天分享的文章是做小样本图像识别的。

主要信息:

二 主要内容

2.1 相关背景

小样本图像识别的方法从整体上来看大概可以分成两个阶段:

  1. representation learning:获取到一个比较好的图像特征提取器;
  2. classifier:通过比对query images和support images进行query image的标签预测

文章首先总结了现有方法的共同点:

  1. 在representation leanring的学习上具有的一个共同点就是都会使用训练图片的类别标签做一个监督学习;
  2. 在classifier学习阶段具有的一个共同点是会将query和support图像之间的整体特征进行比较,例如ProtoNet就是将query的特征与support set中每个类中心的特征进行比较。

2.2 本文工作

文章首先支持现有方法的不足:

  1. 完全依靠类别标签进行特征学习的方式会导致只能学习到跟类别相关的信息,而忽略其他更加通用的特征表示;
  2. 在做图像比较的时候,图像中的一些重要objects和scenes通常是local的,直接用整体特征进行比较的效果不一定是最好的;

相对应地,文章提出了从两个方面进行优化:

  1. 针对第一个问题,提出引入自监督学习的方法SimCLR来获取更加通用的图像特征表示;
  2. 针对第二个问题,提出基于Transformer的新结构CrossTransformer,希望能够进行local信息的图像匹配;

三 方法介绍

文章是基于ProtoNet结构的,所以首先介绍下ProtoNet, 然后分别介绍以上两点novelty。

3.1 ProtoNet

ProtoNet算的上是小样本图像识别领域最flagship的工作了,这里只做个简单的介绍。

N-way-K-shot
给定一堆带标签可供参考的Support Images,具体表示为有N个类别,每个类别有K张带标注的图像,以及一个等待被分类的query image (query image的类别一定属于N个类别),我们需要根据support images预测出query image的类别标签。

key idea:
Protonet的想法非常直接但有效。即对每张图像都先用神经网络得到一个特征表示,然后对support set中每个类别c的所有特征取一个平均,作为这个类别的类中心。最后比较query feature跟各个类中心之间的距离,取最近的一个类别作为预测结果。

3.2 SSL with SimCLR

这里的想法也比较直接,就是觉得自监督学习得到的特征表示不仅对semantic敏感,而且对属于相同类别的不同图片也具有区分度,可以理解为只用class informaction进行监督学习得到的特征是class-level的,SSL学习到的是instance-level的,因此作者认为SSL学习到的特征泛化性会更好。

具体的做法也比较简单。为了区分原来的episode和现在用自监督的episode, 分别用MD-categorization episode以及SimCLR episode来表示它们。在训练的过程中随机转化50%的MD-categorization episode为SimCLR episode, 对SimCLR episode用SimCLR中的方法进行增强,然后对query image也进行增强,最后用各自对应的loss function进行优化。

:( 这边的具体细节感觉只看文章还不是特别清楚,可能需要感兴趣的同学可以自己看看他们的code

3.3 CrossTransformers

这部分都是基于Transformer构建的,如果之前完全不了解的话或许是会比较困难的,建议看看原文:https://arxiv.org/abs/1706.03762, 或者推荐一个我个人最推荐的blog:https://zhuanlan.zhihu.com/p/48508221

文章的主要框架图如下图所示:

第一张是文章原图,第二张是我在原图的基础上把各个重要部分对应的数据维度标注上去以及补充了额外内容的图,可以对照着看。

文章原图
带标记图

主要的pipeline包括以下几步:

  1. 首先看输入,给定最左边的一个query image x_q, 以及最上面的support set中类别为c的几个图像{x_1^c, x_2^c, ...}, 网络的目的是要获取到一个query-specific的类中心(不再是原始ProtoNet版本中直接取平均的方法)
  2. 首先注意到不管是对于query还是对于support images,都是先用一个\phi()得到图像的特征表示,这里文章中用的是ResNet,并且去掉了最后一个pooling层,所以得到的特征维度为R^{H`^ \times W^` \times D}
  3. 接下来就是基于query,key,value的attention操作。这里的query是指query image,而key和value都是指support sets。理解这一点对理解整个attention还挺重要的。

网络图中的query heads,key heads都是将输入特征从D维度映射到d_k维度,而value heads将输入特征从D维度映射到d_v维度。

具体地,(建议对着图看)

  • query heads将query特征从R^{H^` \times W^` \times D}维度映射到R^{H^` \times W^` \times d_k}维度(图中shi黄色的框框);
  • key heads将support特征从R^{H^` \times W^` \times D}维度映射到R^{H^` \times W^` \times d_k}维度(图中亮黄色的框框,左右两个表示的是一样的意思,看第一个就行了);
  • value heads将support特征从R^{H^` \times W^` \times D}映射到R^{H^` \times W^` \times d_v}维度(图中红色框框,也看其中一个就可)。
  1. 然后就是计算query和key之间的attention,我们还是只看一个query(shi黄色框)和一个support图像特征(第一个亮黄色框框),经过映射之后两个的特征维度都是R^{H^` \times W^` \times d_k},对于query中任意一个位置p和support中的任意一个位置m,特征维度都是d_k, 通过向量点乘的方法可以得到这2个点之间的attention值,图中小黑点在的位置。对每个HxW中的点都计算一次attention,最终就会得到一张query和一张support的attention mapa_1^c, 当然还做了一个softmax操作得到更新后的attention map\tilde{a_1^c}。对suppport中的多张图采取同样的操作就会得到多张attention map。

  2. 最后就是利用这些attention maps对support set中不同图像的vaule特征进行加权平均。这部分操作可以理解为,对于<query, support image i>, 对于HxW中的任意一个位置,都用其第i张attention map的值乘上对应第i个红色框框位置的value,最后把不同support images的结果值进行相加得到最终query-aligned prototype的特征表示,其维度为R^{H^` \times W^` \times d_v}

  3. 到这里为止我们获取到了query-aligned prototype R^{H^` \times W^` \times d_v}。 但是要做小样本预测到这里还没有完全完整,我把第二张图中把剩下的部分补上了。对于query image,其实也用value head做了一个映射,得到一个query image的value 特征表示,其维度为R^{H^` \times W^` \times d_v}, 跟prototype的维度是一样的,这样就可以比较这两者之间的距离,进而进行label预测了.

五 写在最后

我在写这个blog的时候,尽量避免了公式的出现,但可能有些地方解释的还是有些不好理解,尤其是crossTransformer部分涉及的符号略多,大家见谅啦。

这篇文章暂时介绍到这里,最后打个不那么相关的广告,我们做小样本视频分类的工作(AMeFu-Net)近期开源了,link: https://github.com/lovelyqian/AMeFu-Net,欢迎大家关注~