无乘加矩阵乘法MADNESS

作者介绍:Davis Blalock :麻省理工学院的博士生,由John Guttag 教授指导。主要研究方向是设计高性能的机器学习算法,以减少人们在机器学习速度、准确性、隐私和安全性之间的妥协。John Guttag:麻省理工学院计算机科学与电气工程教授、ACM Fellow,领导麻省理工学院计算机科学和人工智能实验室(CSAIL)的临床与应用机器组。该小组负责开发和应用先进的机器学习和计算机视觉技术,以解决各种临床相关问题。

矩阵乘法是机器学习中最基本和计算密集型的操作之一,因此其速度对于整个机器学习算法速度来讲至关重要。研究者因此设计了一种基于零乘加法,以哈希、平均和字节变换作混合为核心操作的新算法,使其运行速度基本上能达到比精确的矩阵乘积快100倍,并且比目前所有的近似方法快10倍左右。

算法背景

该算法基于内积量化算法(Product Quantization),简称PQ,它是一种经典的近似于矩阵内积和欧几里得距离的向量量化算法,是几乎所有类似方法的基础。该算法基于的想法,其中很小,且具有特殊的结构,允许内积被快速地计算出来。PQ算法具体来讲分为以下四个步骤:

一、原型学习(Prototype Learning)

设是一个训练集,K是每个子空间的原型数,C是多个子空间,是与每个子空间相关的互斥的、集体详尽的索引集。PQ的训练任务是学习C组原型和分配,使得:最小化。它通过在每个子空间中分别运行K-means算法,并使用生成的质心和赋值来填充和来实现这一点。

二、编码函数(Encoding Function)

给定学习到的原型,PQ将A的每一行a替换为每个C子空间中的K-means质心分配。形式为:并将把得到的索引序列称为a的编码,并将K质心集作为编码本。

三、建立表(Table Construction)

使用这些相同的原型,PQ算法在B的每列b的每个C子空间中构造了一个查找表,其中:且研究者的实验表明,与较大的K或浮点表相比,将K=16和将查找表量化为8位可以提供巨大的加速。

四、汇集(Aggregation)

给定a的编码和b的查找表,乘积可以近似为:直观上来理解则如下图所示:g(·)函数返回与每个子空间中的数据向量a最相似的原型的索引。h(·)函数计算查询向量b和每个子空间中的每个原型之间的点积的查找表。聚合函数f(·,·)对每个索引对应的表条目进行汇总。

图1(文献[1]中Figure 2)

作者在PQ算法的基础上,引入了一个新的g(a)函数,使在更小的矩阵上也能产生很大的加速。该函数背后的主要思想是通过局部敏感散列来确定“最相似”的原型,也就是说,该算法不是计算一个子向量和每个原型之间的欧氏距离,进而选取距离最近的原型作为近似,而是将散列到一个K-buckets(实际上就是K位散列值),其中相似的子向量倾向于散列到同一bucket。将原型(prototype)设置为散列到同一个bucket的子向量的平均值。

基于这个想法研究者试图采用一个训练集,同时希望做的工作甚至比一个线性变换还要少得多,结果发现现有的哈希函数不满足要求。因此,作者设计了自己的可训练哈列函数族。所选择的哈希函数族是平衡二叉回归树,树的每一个叶子节点作为一个哈希bucket。如果某个索引j上的值低于节点特定的阈值v,则通过从根目录遍历树并移动到左子节点来选择向量x的叶子,右子节点也是如此。算法1如下:

算法1

形式上,考虑一组四个指标和四个分割阈值数组,vt的长度为。该算法将一个向量x映射到一个散列值索引上。这个函数只依赖于输入向量中恒定数量的索引,如果被编码行的矩阵按列主序存储,就可以很容易地向量化。进一步通过算法2学习哈希函数的参数,分割指数和分割阈值利用贪婪树构造算法对训练矩阵进行了优化。为了描述这个算法,研究者引入了bucket()的概念,它是在树的t级(对应平衡二叉回归树节点的深度)中映射到节点i(同一深度的某个节点)的向量集。树的根是0级,包含所有的向量。定义误差(SSE)损失的平方和如下:

算法2如下:

算法2

对原型进行优化

研究者并没有直接将基于哈希的编码函数放到PQ和近似矩阵积中,而是提供了两种额外的改进方法:一种无运行时开销的原型优化方法,以及一种快速求和低比特宽整数的方法。

一、选择原型

设是一个矩阵,其大小为的对角线块由每个子空间c中的K个学到的原型组成。训练矩阵可以近似地重建为,其中G可以在每个子空间中选择合适的原型。G的行通过连接对应一行的每个分配的一个one-hot编码表示而形成。现在需要在G和的条件下优化P。这是一个普通的最小二乘问题,可以用岭回归来解决:岭回归中可以通过交叉验证找到λ来获得更好的性能,文中为了简单起见,固定λ=1。

二、快速8位汇集 设为B中所有M列的查找表的张量。给定编码G,将函数定义为:在非饱和加法指令的情况下, 因为T的条目存储为8位值,所以在执行任何求和指令之前如果立即将每个查找条目上转换到16位,将导致其存取吞吐量是直接用8位值的一半。因此该方法求和过程不使用additional指令,而是采用averaging指令,如此做法之后计算结果造成的精度损失将发生在低bit位而非高bit位,整体误差会变大,但计算速度大大提升,即选择牺牲少量的准确性来显著提高速度。

MADDNESS 到底有多快?

研究者首先分析了MADDNESS的原始速度。他们为各种矢量量化方法计算 g(A) 函数的时间,结果表明,MADDNESS 比现有方法快两个数量级,其吞吐量随矩阵行的数量而增加。

且如下图所示,MADDNESS的速度几乎是 Bolt(上图蓝色虚线)的两倍:

在广泛使用的 CIFAR-10 和 CIFAR-100 数据集上和近似Softmax线性分类器的对比也可以看出,MADDNESS 实现了比任何现有方法更好的速度精度权衡:

进一步,为了评估该方法在更大、多样性更强的数据集上的表现,研究者在来自UCR的Time Series Archive 数据集上训练了kernel分类器。结果表明,MADDNESS 在给定准确率的情况下明显快于其他方案:

为了测试MADDNESS的极限,研究者对将小滤波器应用于图像的各种技术能力进行了基准测试。结果表明,只有 MADDNESS 比精确矩阵乘积更有优势:

参考文献

[1] Blalock, D. & Guttag, J. (2021). Multiplying Matrices Without Multiplying (cite arxiv:2106.10860Comment: To appear at ICML 2021)

(0)

相关推荐