ICCV2021 MIT-IBM沃森开源CrossViT:Transformer走向多分支、多尺度

详细信息如下:
  • 论文链接:https://arxiv.org/abs/2103.14899

  • 项目链接:https://github.com/IBM/CrossViT

导言:
与卷积神经网络相比,最近出现的视觉Transformer(ViT)在图像分类方面取得了很好的结果。受此启发,在本文中,作者研究了如何学习Transformer模型中的多尺度特征表示来进行图像分类 。为此,作者提出了一种双分支Transformer来组合不同大小的图像patch,以产生更强的图像特征。本文的方法用两个不同计算复杂度的独立分支来处理小patch的token和大patch的token,然后这些token通过attention机制进行多次的交互以更好的融合信息。
此外,为了减少计算量,作者开发了一个简单而有效的基于cross-attention的token融合模块。在每一个分支中,它使用单个token(即 [CLS] token)作为query,与其他分支交换信息。本文提出cross-attention的计算复杂度和显存消耗与输入特征大小呈线性关系 。实验结果表明,本文提出的CrossViT的性能优于其他基于Transformer和CNN的模型。例如,在ImageNet-1K数据集上,CrossViT比DeiT的准确率高了2%,但是FLOPs和模型参数增加的非常有限。

      01      

Motivation
Transformer使NLP任务中序列到序列建模的能力取得了很大的飞跃。Transformer在NLP中的巨大成功激发了其在计算机视觉领域的应用。在ViT之前的一些工作主要将Transformer中的Self-Attention和CNN进行结合。虽然这些结合CNN和Self-Attention方法达到了比较不错的性能,但与纯粹的基于Self-Attention的Transformer相比,它们在计算方面的可拓展性非常有限。
ViT使用一系列embedding式的图像patch作为标准Transformer的输入,这是第一个与CNN模型性能相当的无卷积Transformer网络。然而,ViT需要非常大数据集,如ImageNet21K和JFT300M来进行预训练。之后的DeiT表明,数据增强和模型正则化可以在较少的数据下训练高性能的ViT模型。在此之后,ViT就逐渐成为了CV任务中的主流模型之一。
在这项工作中,作者研究了如何学习Transformer模型中的多尺度特征表示来进行图像识别 。多尺度的特征已经在很多工作中证明了对于CV任务是有效的,但多尺度特征对视觉Transformer的潜在好处仍有待验证。受到多分支CNN架构的启发,作者提出了一个双分支Transformer来组合不同大小的图像patch,以产生更强的视觉特征用于图像分类 。
本文的方法用两个具有不同计算复杂度的独立分支来分别处理大patch的token和小patch的token,这些token多次融合以相互补充信息。本文的重点是研究并设计适合视觉Transformer的多尺度特征融合方法 。在本文中,作者通过一个有效的交叉注意模块来实现这一点,其中每个Transformer分支创建一个non-patch token(即 [CLS] token)作为代理,并通过attention机制与另一个分支交换信息。(由于这里只使用[CLS] token进行信息交互,所以这一步attention的计算复杂度是线性的,而非二次的。
在ImageNet-1K数据集上,CrossViT比DeiT的准确率高了2%,但是FLOPs和模型参数增加的非常有限(如上图所示)。

      02      

方法

2.1.Vision Transformer的概述

视觉Transformer(ViT)首先通过将图像按照一定的patch大小划分,然后将图像转换为一个patch的序列,并将每个patch线性投影成embedding。为了执行分类任务,还要在序列中添加了一个额外的分类token(CLS)。此外,由于Transformer编码器中的自注意是与位置无关的,而视觉应用高度需要位置信息,因此ViT向每个token添加了位置embedding,包括CLS token。
然后,将所有token输入到堆叠的Transformer编码器,最后使用CLS token进行分类。Transformer编码器由一系列块组成,其中每个块包含带有FNN和多头自注意(MSA)。FFN在隐藏层上包含具有扩展比为r的两层多层感知器,在第一个线性层后应用一个GELU非线性激活函数。在每个块之前应用Layer normalization(LN),在每个块之后应用残差连接。ViT的输入和第k个块的处理可以表示为:
其中,和分别为CLS和patch token,为位置embedding。N和C分别是patch token的数量和embedding的维度。
值得注意的是,ViT与CNN的一个非常不同的设计是CLS token。在CNN中,最终的embedding通常是通过对所有空间位置的特征进行平均,而ViT使用与每个Transformer编码器上的patch token相互作用的CLS token作为最终的embedding。因此,我们可以认为CLS token是一个总结所有patch token的代理,因此作者提出的模块是基于CLS token设计的双路径多尺度的ViT。

2.2.多尺度Vision Transformer

patch大小会影响ViT的准确性和复杂性;使用细粒度的patch大小,ViT性能更好,但会导致更高的FLOPs和显存消耗。例如,patch大小为16的ViT比patch大小为32的ViT好6%,但前者需要更多4×的FLOPs。受此启发,作者提出的方法是利用来自更细粒度的patch大小的优点,同时平衡复杂性。更具体地说,作者首先引入了一个双分支ViT,其中每个分支处理不同patch大小的特征,然后采用一个简单而有效的模块来融合分支之间的信息。
上图展示了交叉注意多尺度视觉Transformer( Cross-Attention Multi-Scale Vision Transformer,CrossViT)的网络结构。模型主要由K个多尺度Transformer编码器组成,其中每个编码器由两个分支组成:
  1. L-Branch :大分支利用粗粒度的patch大小(),这个分支有更多的Transformer编码器和更大的embedding维度。
  2. S-Branch :小分支对细粒度的patch大小()进行操作,这个分支具有更少的编码器和更小的embedding维度。
两个分支的输出特征在Cross-Attention中融合L次,利用末端的两个分支对CLS token进行预测。对于两个分支的每个token,作者还在多尺度Transformer编码器之前添加了一个可学习的位置embedding,以学习位置信息。

2.3.多尺度特征融合

有效的特征融合是学习多尺度特征表示的关键。在本文中,作者探索了四种策略,如上图所示。
设是分支i上的token序列(包括patch和CLS token),其中i可以是l或者s,分别代表大分支和小分支。和分别代表分支i的CLS和patch token。

All-Attention Fusion

All-Attention Fusion是简单地concat来自两个分支的所有token,而不考虑每个token的属性,然后通过自注意模块融合信息,如上图所示。但是这种方法的时间复杂度也是跟输入特征大小呈二次关系,因为所有的token都是通过自注意模块计算的。All-Attention Fusion方案的输出可以表示为:
其中,和是用来对齐特征通道维度的投影函数。

Class Token Fusion

CLS token可以看作是一个分支的抽象全局特征表示,因为在ViT中,它可以作为用来预测结果的最终embedding。因此,可以直接对两个分支的CLS token求和,如上图所示。这种方法计算上非常有效,因为只需要处理一个token。这个融合模块的输出可以表示为:

Pairwise Fusion

上图显示了这两个分支是如何成对融合的。基于patch token位于图像的空间位置,启发式融合方法是根据它们的空间位置将它们合并。然而,这两个分支处理不同大小的patch,因此具有不同数量的patch token。因此,作者首先执行插值操作来对齐空间大小,然后以成对的方式融合两个分支的patch token。另一方面,将这两个分支的CLS token分别融合。L-Branch和S-Branch的成对融合的输出可以表示为:

Cross-Attention Fusion

上图显示了本文提出的Cross-Attention Fusion的基本思想,其中融合涉及到一个分支的CLS token和另一个分支的patch token。此外,为了更有效地融合多尺度特征,作者首先利用每个分支的CLS token作为代理,在另一个分支的patch token之间交换信息,然后将其投影到自己的分支。
由于CLS token已经学习了自己分支中所有patch token之间的抽象信息,因此与另一个分支中的patch token的交互有助于融合不同尺度的信息。与其他分支token融合后,CLS token在下一层Transformer编码器上再次与自己的patch token交互,在这一步中,它又能够将来自另一个分支的学习信息传递给自己的patch token,以丰富每个patch token的表示。
大分支(L-Branch)的 cross-attention模块结构如上图所示。具体来说,对于分支l,它首先从s分支中收集patch token,并将自己的CLS token与l分支的patch token进行concat,表示如下:
其中,是用于特征对齐的投影矩阵。然后,该模块在和之间执行交叉注意(CA),其中CLS token是唯一的query,因为patch token的信息已经被融合到了CLS token中。在数学上,CA可以表示为:
由于在query中只使用CLS token,所以在交叉注意(CA)中生成attention map的计算复杂度和显存消耗是线性的,而不是像在full-attention中那样是二次的,这使得整个过程更高效。此外,就像Self-Attention一样,作者也在CA中使用多个head,并将其表示为MCA。在交叉注意后,没有使用前馈网络FFN。具体来说,CA的输出的计算如下:
与其他三种简单的启发式方法相比,交叉注意获得了最好的精度,同时对多尺度特征融合有效。

2.4.  CrossViT

基于上面提出的结构,作者提出了多个不同计算量和参数量的CrossViT模型,具体结构如上表所示。

      03      

实验

3.1. Main Results

Comparisons with DeiT

上表比较了CrossViT和相似计算量下的DeiT进行的比较,可以看出,相比于DeiT,CrossViT具有明显的性能优势。

Comparisons with SOTA Transformers

上表展示了CrossViT和SOTA模型的对比结果,与ViT-B相比,CrossViT-18†的准确率高了4.9%(77.9% vs 82.8%)。

Comparisons with CNN-based Models

上表比较了CrossViT和CNN模型的性能。可以看出,除了EfficientNet,CrossViT相比于大多数CNN结构还是具有性能上的优势的。

Transfer Learning

为了验证本文方法的泛化性能,作者在不同的数据集上做了实验,可以看出CrossViT在其他数据集上同样具有性能的优势。

3.2. Ablation Studies

Comparison of Different Fusion Schemes

上表显示了不同fusion方法的性能。可以看出Cross-Attention相比于其他方法具有性能上的优势。

Effect of Patch Sizes

与Baseline相比,上表中的A展示了Patch Size的影响,可以看出使用(12,16)的patch大小能够在更少的FLOPs下能够获得了更好的准确性。

Channel Width and Depth in S-branch

尽管cross-attention设计的比较轻量级,作者在实验中还是测试了使用更复杂的s分支的性能,如上表(B和C)所示。这两种模型都增加了FLOPs和参数,但没有提高精度。

Depth of Cross-Attention and Number of Multi-Scale Transformer Encoders

为了增加两个分支的融合频率,可以堆叠更多的交叉注意模块(L),也可以堆叠更多的多尺度Transformer编码器(K)。如上表D和E所示,可以看出过频繁融合并不能提供任何性能的改进。

      04      

总结
在本文中,作者提出了一种用于学习多尺度特征的双分支视觉Transformer——CrossViT,以提高图像分类的识别精度。为了有效地结合不同尺度的图像patch token,作者进一步开发了一种基于交叉注意的融合方法,能够在线性时间复杂度下,进行两个分支的有效信息交换。通过大量的实验,作者证明了CrossViT的性能优于目前的CNN和Transformer模型。

      05      

思考

在Vision Transformer领域中,除了这篇CrossViT,近期还有另一篇论文CrossFormer(CrossFormer: A Versatile Vision Transformer Based on Cross-scale Attention )同样研究了Vision Transformer中的多尺度特征。下面我们来比较一下这两篇论文的异同点。

CrossFormer

CrossViT

两篇文章的模型结构如上图所示,相同的是,这两篇文章都用了视觉特征的多尺度信息 ,来获得更加丰富和鲁棒的视觉特征,从而提升视觉任务的性能。
不同之处主要存在以下几点:
  1. CrossFormer是一个多阶段结构,不同的阶段具有不同的分辨率(这一点和CNN很像,空间维度不断缩小,通道维度不断增加);而CrossViT与Transformer更相似,在所有的Block中都保持了相同的分辨率和通道维度。

  2. CrossFormer的多尺度特征体现在每个Stage内,它的多尺度是由于CrossFormer中每个stage的CEL模块在进行patch embedding的时候进行了多尺度embedding造成的;CrossViT的多尺度特征体现在Transformer Block之前的patch embedding,它在patch embedding的时候生成了两个尺度的特征,然后CrossViT中每个Block在比较小的计算量下,进行了这两个尺度特征的信息交互。

  3. CrossFormer的多尺度信息会更加丰富一些(第1个stage有4个尺度的特征信息;第2/3/4个stage都有2个尺度的特征信息);CrossViT在整个计算过程中存在2个尺度的视觉信息(即文中的S-Branch和L-Branch的尺度)

作者介绍

研究领域:FightingCV公众号运营者,研究方向为多模态内容理解,专注于解决视觉模态和语言模态相结合的任务,促进Vision-Language模型的实地应用。

知乎/公众号:FightingCV

END

欢迎加入「图像分类交流群👇备注:IC

(0)

相关推荐