ICCV 2021 | 最快视觉Transformer!Facebook提出LeViT:快速推理的视...

AI/CV重磅干货,第一时间送达

CVer

一个专注侃侃计算机视觉方向的公众号。计算机视觉、图像处理、机器学习、深度学习、C/C++、Python、诗和远方等。
204篇原创内容
公众号
本文转载自:集智书童

LeViT: a Vision Transformer in ConvNet’s Clothing for Faster Inference

论文:https://arxiv.org/abs/2104.01136

代码(刚刚开源):

https://github.com/facebookresearch/LeViT

吸取CNN优点!LeViT:快速推理的视觉Transformer,在速度/准确性的权衡方面LeViT明显优于现有的CNN和视觉Transformer,比如ViT、DeiT等,而且top-1精度为80%的情况下LeViT比CPU上的EfficientNet快3.3倍。
作者单位:Facebook

1 简介

本文的工作利用了基于注意力体系结构中的最新发现,该体系结构在高度并行处理硬件上具有竞争力。作者从卷积神经网络的大量文献中重新评估了原理,以将其应用于Transformer,尤其是分辨率降低的激活图。同时作者还介绍了Attention bias,一种将位置信息集成到视觉Transformer中的新方法。

图1 LeViT性能对比

最终作者提出了LeVIT:一种用于快速推理的混合神经网络。考虑在不同的硬件平台上采用不同的效率衡量标准,以最好地反映各种应用场景。作者通过广泛的实验表明该方法适用于大多数体系结构。总体而言,在速度/准确性的权衡方面,LeViT明显优于现有的卷积网络和视觉Transformer。例如,在ImageNet Top-1精度为80%的情况下,LeViT比CPU上的EfficientNet快3.3倍。

相同计算复杂度的情况下Transformer为什么快?

大多数硬件加速器(gpu,TPUs)被优化以用来执行大型矩阵乘法。在Transformer中,注意力机制和MLP块主要依靠这些操作。相比之下,卷积需要复杂的数据访问模式,因此它们的操作通常受io约束。这些考虑对于我们探索速度/精度的权衡是很重要的。

本文主要贡献:

  1. 采用注意力机制作为下采样机制的multi-stage transformer 结构;

  2. 一种计算效率高的patch descriptor,可以减少第一层特征的数量;

  3. 使用Translation-invariant attention bias取代ViT中的位置嵌入;

  4. 为了提高给定计算时间的网络容量,作者重新设计了Attention-MLP Block。

2 LeViT的设计

2.1 LeViT设计原则

LeViT以ViT的架构和DeiT的训练方法为基础,合并了对卷积架构有用的组件。第1步是获得Compatible Representation。如果不考虑classification embedding的作用,ViT就是一个处理激活映射的Layer的堆叠。

实际上,中间“Token”嵌入可以看作是FCN体系结构中传统的C×H×W激活映射(BCHW格式)。因此,适用于激活映射(池、卷积)的操作可以应用于DeiT的中间表征。

LeViT优化了计算体系结构,不一定是为了最小化参数的数量。ResNet系列比VGG更高效的设计原则之一是在其前2个阶段使用相对较小的计算预算应用strong resolution reductions。当激活映射到达ResNet的第3阶段时,其分辨率已经缩小到足以将卷积应用于小的激活映射,从而降低了计算成本。

2.2 LeViT组件

1、Patch embedding

初步分析表明,在transformer组的输入上应用一个小卷积可以提高精度。因此在LeViT中作者选择对输入应用4层3×3卷积(stride2)来降低分辨率。channel的数量是C=3,32,64,128,256。

以上操作减少了对transformer下层的激活映射的输入,同时不丢失重要信息。LeViT-256的patch extractor用184 MFLOPs将图像形状(3,224,224)转换为(256,14,14)。作为比较,ResNet-18的前10层使用1042 MFLOPs执行相同的dimensionality reduction。

为什么在transformer组的输入上应用一个小卷积可以提高精度?

2、No classification token

为了使用BCHW张量形式,LeViT删除了classification token。类似于卷积网络,在最后一个激活映射上使用GAP来代替,这将产生一个用于分类器的embedding。在训练中进行蒸馏,作者分别训练分类和蒸馏的Head。在测试时,平均2个Head的输出。在实践中,LeViT可以使用BNC或BCHW张量格式。

3、Normalization layers and activations

ViT架构中的FC层相当于1x1卷积。ViT在每个注意点和MLP单元之前使用层归一化。对于LeViT,每次卷积之后都要进行BN操作。然后与residual connection连接起来的每个BN权重参数初始化为零。BN可以与之前的卷积合并来进行推理,这比层归一化有运行优势(例如,在EfficientNet B0上,这种融合将GPU的推理速度提高了2倍)。而DeiT使用GELU函数,而LeViT的非线性激活都是Hardswish。

class Linear_BN(torch.nn.Sequential):
    def __init__(self, a, b, bn_weight_init=1, resolution=-100000):
        super().__init__()
        self.add_module('c', torch.nn.Linear(a, b, bias=False))
        bn = torch.nn.BatchNorm1d(b)
        torch.nn.init.constant_(bn.weight, bn_weight_init)
        torch.nn.init.constant_(bn.bias, 0)
        self.add_module('bn', bn)

global FLOPS_COUNTER
        output_points = resolution**2
        FLOPS_COUNTER += a * b * output_points

@torch.no_grad()
    def fuse(self):
        l, bn = self._modules.values()
        w = bn.weight / (bn.running_var + bn.eps)**0.5
        w = l.weight * w[:, None]
        b = bn.bias - bn.running_mean * bn.weight / \
            (bn.running_var + bn.eps)**0.5
        m = torch.nn.Linear(w.size(1), w.size(0))
        m.weight.data.copy_(w)
        m.bias.data.copy_(b)
        return m

def forward(self, x):
        l, bn = self._modules.values()
        x = l(x)
        return bn(x.flatten(0, 1)).reshape_as(x)

4、Multi-resolution pyramid

LeViT在transformer架构中集成了ResNet stage。在各个stage中,该体系结构类似于一个visual transformer:一个带有交替MLP和激活块的残差模块。下面是注意块的修改。

class Attention(torch.nn.Module):    def __init__(self, dim, key_dim, num_heads=8,                 attn_ratio=4,                 activation=None,                 resolution=14):        super().__init__()        self.num_heads = num_heads        self.scale = key_dim ** -0.5        self.key_dim = key_dim        self.nh_kd = nh_kd = key_dim * num_heads        self.d = int(attn_ratio * key_dim)        self.dh = int(attn_ratio * key_dim) * num_heads        self.attn_ratio = attn_ratio        h = self.dh + nh_kd * 2        self.qkv = Linear_BN(dim, h, resolution=resolution)        self.proj = torch.nn.Sequential(activation(), Linear_BN(            self.dh, dim, bn_weight_init=0, resolution=resolution))

        points = list(itertools.product(range(resolution), range(resolution)))        N = len(points)        attention_offsets = {}        idxs = []        for p1 in points:            for p2 in points:                offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))                if offset not in attention_offsets:                    attention_offsets[offset] = len(attention_offsets)                idxs.append(attention_offsets[offset])        self.attention_biases = torch.nn.Parameter(            torch.zeros(num_heads, len(attention_offsets)))        self.register_buffer('attention_bias_idxs',                             torch.LongTensor(idxs).view(N, N))

        global FLOPS_COUNTER        #queries * keys        FLOPS_COUNTER += num_heads * (resolution**4) * key_dim        # softmax        FLOPS_COUNTER += num_heads * (resolution**4)        #attention * v        FLOPS_COUNTER += num_heads * self.d * (resolution**4)

    @torch.no_grad()    def train(self, mode=True):        super().train(mode)        if mode and hasattr(self, 'ab'):            del self.ab        else:            self.ab = self.attention_biases[:, self.attention_bias_idxs]

    def forward(self, x):  # x (B,N,C)        B, N, C = x.shape        qkv = self.qkv(x)        q, k, v = qkv.view(B, N, self.num_heads, -                           1).split([self.key_dim, self.key_dim, self.d], dim=3)        q = q.permute(0, 2, 1, 3)        k = k.permute(0, 2, 1, 3)        v = v.permute(0, 2, 1, 3)

        attn = (            (q @ k.transpose(-2, -1)) * self.scale            +            (self.attention_biases[:, self.attention_bias_idxs]             if self.training else self.ab)        )        attn = attn.softmax(dim=-1)        x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)        x = self.proj(x)        return x

5、Downsampling

在LeViT stage之间,一个缩小的注意块减少了激活映射的大小:在Q转换之前应用一个subsampling,然后传播到soft activation的输出。这将一个大小为的输入张量映射到一个大小为的输出张量。由于尺度的变化这个注意块的使用没有残差连接。同时为了防止信息丢失,这里将注意力头的数量设为。

class Subsample(torch.nn.Module):
    def __init__(self, stride, resolution):
        super().__init__()
        self.stride = stride
        self.resolution = resolution

def forward(self, x):
        B, N, C = x.shape
        x = x.view(B, self.resolution, self.resolution, C)[
            :, ::self.stride, ::self.stride].reshape(B, -1, C)
        return x

6、Attention bias instead of a positional embedding

在transformer架构中的位置嵌入是一个位置依赖可训练的向量,在将token嵌入输入到transformer块之前,将其添加到token嵌入。如果没有它,转换器输出将独立于输入标记的排列。位置嵌入的Ablations会导致分类精度的急剧下降。

然而,位置嵌入只包含在注意块序列的输入上。因此,由于位置编码对higher layer也很重要,所以它很可能仍然处于中间表示中。

因此,LeViT在每个注意块中提供位置信息,并在注意机制中明确地注入相对位置信息:只是在注意力图中添加了注意偏向。对于每个head ,每2个像素和之间的标量值计算方式为:

第一项是经典的注意力。第二个是translation-invariant attention bias。每个Head有H×W参数对应不同的像素偏移量。对称差异和鼓励用 flip invariance进行训练。

self.attention_biases = torch.nn.Parameter(            torch.zeros(num_heads, len(attention_offsets)))

7、Smaller keys

由于translation-invariant attention bias偏置项减少了key对位置信息编码的压力,因此LeViT减少了key矩阵相对于V矩阵的大小。如果key大小为, V则有2D通道。key的大小可以减少计算key product 所需的时间。

对于没有残差连接的下采样层,将V的维数设置为4D,以防止信息丢失。

8、Attention activation

在使用常规线性投影组合不同Heads的输出之前,对product 应用Hardswish激活。这类似于ResNet bottleneck residual block,V是一个1×1卷积的输出,对应一个spatial卷积,projection是另一个1×1卷积。

9、Reducing the MLP blocks

在ViT中,MLP residual块是一个线性层,它将嵌入维数增加了4倍,然后用一个非线性将其减小到原来的嵌入维数。但是对于视觉架构,MLP通常在运行时间和参数方面比注意Block更昂贵。

对于LeViT, MLP是1x1卷积,然后是通常的BN。为了减少计算开销,将卷积的展开因子从4降低到2。一个设计目标是注意力和MLP块消耗大约相同数量的FLOPs

2.3 LeViT家族

3 实验

3.1 速度对比

ResNet50的精度,但是是起飞的速度。

3.2 SOTA对比


论文PDF和代码下载

(0)

相关推荐