分层级联Transformer!苏黎世联邦提出TransCNN: 显著降低了计算/空间复杂度!
导读
本文是苏黎世联邦理工-Luc Van Gool组联合南开大学和电子科技大学在Transformer上的最新工作,这项工作主要提出了一个新的分层级联多头自注意力模块(MHSA),通过分层级联的方式显著降低了计算/空间复杂度。H-MHSA 模块可轻松插入任何CNN架构中,并且可以通过反向传播进行训练。基于此,我们提出了一种新的骨干网络叫做TransCNN,它完美继承了CNN和Transformer的优点,与当前最具竞争力的基于CNN和Transformer的网络架构相比,TransCNN在计算/空间复杂度和准确率上实现了最先进的性能。
作者单位:苏黎世联邦理工(Luc Van Gool大佬组)、南开大学, 电子科大 论文:https://arxiv.org/abs/2106.03180 代码: https://github.com/yun-liu/TransCNN
摘要与介绍
模型架构
模型方法
接下来,我们首先回顾一下视觉Transformer。然后,我们将详细介绍本文提出的H-MHSA模块。
Revisit Vision Transformer
Transformer主要依赖MHSA来建模长期特征依赖关系,这里,我们用表示输入, 和分别代表Token的数量以及每个token的特征维度,然后,我们有the query :, the key: , the value: ,并且
在假设输入和输出具有相同维度的情况下,传统的MHSA可以通过下式计算:
其中表示近似归一化,将Softmax函数应用于矩阵的每一行。注意,为了简单起见,我们在这里省略了多个头的计算。矩阵乘积具体做法是首先计算每对token之间的相似度。然后,在所有token的组合基础上再派生获取得到每个新的token。在MHSA计算后,可以进一步添加残差连接以方便优化,如:
这里,和分别表示:最大池化和平均池化。进一步针对
最后,我们得到:
残差连接:这个过程将不断迭代,直到足够小。然后我们停止切分网格块。H-MHSA的最终输出为:
其中表示将注意力特征上采样到原始大小,是特征投影的权重矩阵。为最大迭代步数。通过这种方式,H-MHSA可以建立全局特征依赖关系,相当于传统的MHSA。很容易证明,在所有都相同的假设下,H-MHSA的计算复杂度为:
所以,我们显著降低了计算复杂度,即从降低到了,并且这里比小得多。同理,空间复杂度也得到了显著降低。
Experiments
本文创新点总结
4. 未来,神经结构搜索技术也可以应用于TransCNN进行参数配置优化。