深度学习模型复杂度分析
Transformer
self-attention和position-wise FFN作为Transformer比较特殊的模块,这里只分析一下它们的复杂度,注意:这里的复杂度既包含时间,也包含空间。

先看self-attention,复杂度为,举个例子,比如说一个句子:,它经过词嵌入后所对应的矩阵的shape应该为,是序列的长度,而每一个单词对应行的行数,即嵌入的维度,见下图:
假设序列长度为,当和相乘时复杂度为。人为计算时一开始会拿中左上角标黄色的块乘左上角标黄色的块,在计算过程中,中标黄色的块会乘遍的第一行,复杂度为,而一行有个黄色的块,则变为,而中有行,所以总复杂度为,同理可得与矩阵相乘时复杂度为,Transformer在论文中提到做自注意力时设置,因而对于一个句子来说,自注意力的复杂度即为。批处理过程中,假设每个batch里都有个句子,但由于是常数,往往不起决定作用,所以总复杂度还是。

参数矩阵一共有四个:,当所有注意力头出来的值拼接成一个大的矩阵,再与相乘便得到最后的注意力值了,的shape为,在注意力中即为,是注意力头的的数量。于是对一个注意力头做自注意时参数的数量为,算上注意力头的话,总参数应为。
在谷歌原始的论文中,的超参设定默认,同理,两次的线性变换都是,时需要与每一个值进行比较,因而也是,故整个复杂度为,所以总复杂度即为。的shape分别为,因此总参数数量为。

进入全连接层之前序列会被拉直,因此,送进全连接层的序列shape为,的shape为,按照上面的推导,全连接层的复杂度应为。

若对于这个模型不了解的读者,建议先阅读本号的另一篇关于textCNN的文章。
继续上面的设置,输入序列长度为,每个词被嵌入的维度为,假设卷积核的宽度为,每种大小的核一共有个。
见下图,我们以的蓝色卷积核为例,看这个卷积核的左上角元素,在卷积过程中会与个元素相乘,这样的元素在卷积核里一共个,所以对一个卷积核的复杂度应为,即为,并且每一种卷积核一共个,因而总复杂度为。


RNN的范式如下所示,假设。一开始初始化的时候的shape为(这里以一个句子以及单向RNN为例),的shape为,复杂度为,加上之后的shape即为。的shape为同理,后面的shape为,的shape为(RNN按步长输入,并非全部直接输入),后面相乘的复杂度为,因此的运算复杂度为。与相乘的复杂度亦为。每个时间步长上都会进行一次,因而总复杂度为。这里可能有些疑惑,以往的复杂度相加不会算上系数,比如说两次复杂度为,还是,这里不同的是相加为,这是个变量。