自注意力真的是Transformer的必杀技吗?MSRA否认三连,并反手给你扔来一个sMLPNet
极市导读
本文构建了一种Attention-free、基于MLP的sMLPNet,主要将MLP模块中的token-mixing替换为稀疏MLP(sparse MLP, sMLP)模块,sMLPNet仅需24M参数即可在ImageNet数据及上取得81.9%top1精度,优于同等大小的CNN与Vision Transformer。>>加入极市CV技术交流群,走在计算机视觉的最前沿
论文链接:https://arxiv.org/pdf/2109.05422.pdf
前言
自今年5月份MLP-Mixer横冲出世以来,4个月的时间里出来了20+MLP相关的论文,从MLP-Mixer、ResMLP、gMLP到A2-MLP、CCS,再到ViP,MLP相关的结构好像一下被探索到头,自ViP之后的MLP相关的paper大多都借鉴了ViP的思想,此外再小心翼翼的加点不一样的小改进。与此同时,优异的Vision Transformer也在尝试将其内在的自注意力替换为MLP,比如Swin Transformer的变种SwinMLP 。
本文所提出的sMLPNet的一个新颖点是:token-mixing部分同时进行了局部与全局依赖建模。局部依赖建模比较容易想到,DWConv即可;全局建模用则从CSWin那里借鉴了一些idea。两者一组合就取得了非常👍🏻的指标。
但是,sMLPNet中引入了BN与DWConv,因此就不能算作是纯MLP架构,可能这也是之前MLP类模型非常小心翼翼的原因吧,生怕影响“出身”(狗头)。
下面正式介绍本篇论文~
Abstract
本文对Transformer中的核心模块自注意力模块进行探索:它是否是Transformer在图像识别任务中取得优异性能的关键?
我们构建了一种Attention-free的、基于MLP的sMLPNet。具体来讲,我们将MLP模块中的token-mixing替换为稀疏MLP(sparse MLP, sMLP)模块 。对于2D图像tokens,sMLP沿轴向进行1DMLP,同时在行/列方向上分别进行参数共享。受益于稀疏连接与参数共享,sMLP模块可以大幅降低模型参数量与计算量,进而避免了干扰MLP类模型性能的“过拟合”问题。
sMLPNet仅需24M参数即可在ImageNet数据及上取得81.9%top1精度,优于同等大小的CNN与Vision Transformer;当参数量扩大到66M,sMLPNet取得了83.4%top1精度,具有与Swin Transformer相当精度。sMLPNet的成功说明:自注意力机制并非Transformer取得优异性能的关键所在。
Method
在这篇文章中,我们旨在回答“有没有可能设计一种无自注意力模块的高性能图像分类模型 ”。我们希望保留CNN网络的某些重要设计思想,同时引入一些受Transformer启发的新成分。因此,我们尝试了如下设计指导思想:
采用与ViT、MLP-Mixer以及Swin Transformer类似的架构以确保公平比较;
显式地为网络注入局部偏置;
探索无自注意力模块的全局依赖;
以金字塔架构执行多阶段处理。
Overall Architecture
上图给出了本文所设计网络架构示意图,类似于ViT、MLP-Mixer以及Swin Transformer,它以RGB图像作为输入,并通过块拆分模块拆分为非重叠块。在网络的第一部分,块尺寸为,每个块reshape为48维向量,然后通过线性层映射为C维嵌入。此时,特征表示为。注:MLP-Mixer中的块尺寸为,故sMLPNet的tokens数量是MLP-Mixer的16倍 。
整个网络包含四个阶段,除第一个阶段以线性嵌入层作为起始外,其他阶段以块合并层作为起始,块合并层起降低空间分辨率提升通道数的作用。 将所得到的的token送入到后续的token-mixing模块(见上图b)与channel-mixing模块。
从上面图示可以看到:token-mixing通过采用深度卷积引入了局部偏置信息 。与此同时,我们还尝试利用所提sMLP建模全局依赖。相比原始MLP,sMLP的权值共享与稀疏连接机制使其更难于过拟合,同时可以大幅降低计算量 。
而channel-mixing则采用了常规的方式,即FFN+GeLU。因为比较简单,故略。
Sparse MLP
我们设计了一种稀疏MLP以消除原始MLP的两个主要缺陷:
我们希望降低参数量以避免过拟合,尤其是当模型在ImageNet训练时;
我们希望降低计算量以促进金字塔架构的多阶段处理,尤其是当模型的输入tokens数量较大时。
上图给出了原始MLP与本文所提sMLP示意图,也就是说:在sMLP中,每个token仅与同行/列的token产生交互关系。下图给出了sMLP的实现架构示意图,可以看到:sMLP包含三个分支,除identity分支外的其他两个分别用于水平/垂直方向token交互 。
假设输入token表示为,在水平交互分支上,输入张量先reshape为,然后通过线性层进行信息聚合;以类似的方式进行垂直交互;最后将三个分支的结果融合。上述过程可以表示如下:
Model Configurations
我们构建了三个不同大小的模型:sMLPNet-T、sMLPNet-S以及sMLPNet-B以对标Swin-T、Swin-S以及Swin-B。注:FFN中的扩展参数。不同模型的超参信息如下:
sMLPNet-T: C=80, 层数=[2;8;14;2];
sMLPNet-S: C=96, 层数=[2;10;24;2];
sMLPNet-B: C=112, 层数=[2;10;24;2]
Experiments
上图给出了所提方案与其他CNN、Transformer以及MLP在ImageNet数据集上的参数量、FLOPs以性能对比,从中可以看到:
相比经典ResNet,RegNetY具有非常大的优势;
除ViT之外的Transformer模型具有与RegNetY相当的性能;
大多MLP模型仅取得了与ResNet相当的性能,其中Mixer-L出现了严重的过拟合问题,侧面说明了:为何sMLP可以通过降低参数量取得更优结果。
在上述方案中,Swin Transformer取得了最佳性能。
尽管所提sMLPNet属于MLP类模型,但其具有与Swin Transformer相当甚至更优的性能。
sMLPNet-T取得了81.9%top1精度,为现有FLOPs小于5B模型中性能最佳者;
sMLPNet-B取得了与Swin-B相当的精度,但模型小25%,FLOPs少近10%。
需要注意的是:sMLPNet从S~B扩增并未看到过拟合;
上述结果表明:无注意力模块同样可以取得SOTA性能,进一步说明,注意力机制并非Transformer模型取得了优异的秘密武器。
Ablation Study
上表对sMLPNet的局部与全局建模的作用进行对比,从中可以看到:
移除DWConv后,模型参数量减少0.1,FLOPs减少0.1B,但性能下降了0.7%。这说明:DWCov是一种非常有效的局部依赖建模方案。
移除Local后,模型进去了80.7%精度。注:这个地方的原文分析看的莫名其妙。
上表对比了sMLP模块在不同阶段的作用,可以看到:
完全移除sMLP后,模型性能为82%;
阶段3中的sMLP移除导致的性能下降最多,约0.8%,而其他阶段均为0.2%。
上表比较了不同融合方式的性能,可以看到:替换后Sum融合后,模型性能大幅下降,从83.1%下降到81.5%。这说明:Concat+FC的融合方式具有更好的性能表现 。
上表比较了sMLP中各分支的排列方式性能,可以看到:
Identity分支可以带来0.4%-0.5%性能提升;
并行处理要比串行处理具有更优的性能。
上表对比了金字塔结构多阶段处理的重要性,从中可以看到:
将sMLP替换为MLP后,模型性能从81.3%下降到77.8%;
将多阶段处理替换后单阶段处理后,模型性能进一步下降到76.8%。
文末思考
We notice that some concurrent Transformer-based models, such as CSWin, have obtained an even higher accuracy than sMLPNet...