Deja Vu技术显著提高大语言模型推理速度的实测数据对比与最新排行榜
ICML 2023 上这篇工作别具匠心,核心思路是在推理过程中,根据当前输入 X 动态地选取部分网络参数参与计算,而不是每次都动用全部权重。从本质上看,这相当于一种“动态剪枝”策略。苹果近期推出的“LLM in flash”方案,几乎就是这项技术的直接工程化落地。由此可见,Deja Vu 在移动端部署大语言模型(尤其是低算力场景)上,拥有相当可观的潜力。
更关键的是,这套方法实现门槛并不高。
在算力极度受限的环境下,提升大语言模型的推理速度,实际价值非常突出。剪枝作为模型轻量化的关键手段,其核心前提是模型参数存在稀疏性——也就是说,我们可以剔除一部分不那么重要的参数,在尽量不影响模型表现的前提下,把参数量降下来。
实际落地时,常规做法是用特定的剪枝算法,从一个较大的模型(暂且称为模型 A)出发,剪出一个较小的模型(模型 B)。推理阶段就用模型 B 完成计算,过程如图 1 所示。
在这种模式下,对于任意输入 X,模型 B 的生成过程都与 X 本身无关。这类方法可称为“静态剪枝”,其背后的稀疏性假设就是“静态稀疏性”。
而 Deja Vu 的做法截然不同——模型 B 的生成与输入 X 直接挂钩。因此它属于“动态剪枝”,对应的假设是“动态稀疏性”。原论文使用的术语是 Contextual Sparsity,强调的正是模型 B 的产生过程依赖输入 X,如图 2 所示。
图 2 清晰地展示了模型 B 的生成过程受输入 X 影响。Deja Vu 要解决的核心问题是:如何基于输入 X,从预训练好的大模型 A 中,快速“变出”一个针对性的小模型 B。
1. Contextual Sparsity
Contextual Sparsity 的核心思想,在于强调对于已经训练好的大模型 A,其参数的重要性是随输入 X 变化的。举个例子,对于两个不同的输入 X1 和 X2,通过图 2 所示的方式剪枝得到的模型 B 并不相同。这就是“上下文”的含义——模型 B 依赖于上下文,而这里的上下文指的就是模型输入。
那么关键问题来了:大模型中真的存在 Contextual Sparsity 吗?
论文的验证方法相当直接。首先,用输入 X 做一次前向推理,过程中记录下那些输出具有较大 L2 范数的 MHA(多头注意力)中的 head,以及 MLP 中的神经元。
具体实现起来也不复杂。以 MHA 为例,每个 head 的输出都是一个矩阵,只需对这个矩阵计算 L2 范数,然后从所有 head 中挑出范数最大的那几个,如下图所示。
在图 1.1 的示例里,输入 X 的序列长度 N=10,维度 d=6。假设 MHA 的 head 数量是 3,那么 MHA 的输出就对应 3 个不同的 head,图中用不同颜色区分。只需计算每个 head 对应输出的 L2 范数,然后找出较大的 head。
图1.1中的MHA忽略了最后的线性层
对于 MLP,情况略有不同。针对某个特定的 token,MLP 的输出是一个向量。但此时不能直接算整个向量的 L2 范数,而要看哪个输出维度的 L2 范数更大。这是因为整个向量是由所有神经元共同计算出来的,而每个维度正好对应一个神经元,如下图所示。
MLP 是逐位置执行的,所以它的物理意义应该用图中红色框的形式来理解:每一列对应输入中每一个 token 的变换。但在找有效神经元时,需要按行来挑选。换句话说,要计算每一行的 L2 范数,然后挑出范数较大的神经元。
找到这些 L2 范数较大的 head 和神经元后,论文用同样的输入 X 再做一次前向推理,但这次只让被挑出来的部分 head 和 MLP 神经元参与计算。结果发现,仅用这些经过筛选的组件,几乎不影响模型的效果。
也就是说,通过这个简单的实验,作者们确认了大语言模型中确实存在 Contextual Sparsity:模型中有一些与输入强相关的高效参数。只用这些参数,就能达到和全参数模型几乎一致的表现。
再重复一遍,不同的输入 X,对应的高效参数部分是不一样的。这与传统的剪枝方法有本质区别,也正是 Contextual Sparsity 这个名字的由来。
注:Transformer原文中,MHA每个head的输出是通过拼接成一个与输入尺寸一致的张量,再接全连接层做变换。这篇论文因为需要挑选部分head,流程上稍有调整:每个head的输出会先接一个全连接层,变换成与输入尺寸一致的张量;然后对所有head的输出求平均和。这样一来,无论选多少个head,MHA的输出尺寸都是一样的。
前面提到,可以通过计算 Attention head 和 MLP 神经元输出的 L2 范数来找到“高效参数”。那么,这个比例大概是多少呢?例如,我们算出了所有 head 和所有神经元的 L2 范数,究竟要选排名前多少的,才能作为“高效参数”?
论文基于 OPT 模型的实验结果如下:
Attention head 的稀疏率大约在 80% 左右
MLP 中神经元的稀疏率大约在 95% 左右
这意味着,实际推理时,我们只需要用大约 20% 的 Attention head 和大约 5% 的 MLP 神经元,就能达到和全参数模型差不多的效果。
2. 稀疏性预测
要利用前面提到的稀疏性来加速推理,就需要有方法能提前、准确地预测,对于当前的输入 X,哪些 head 和哪些 MLP 神经元是“高效参数”。
这里需要两个预测模型。一个用来预测 MHA 里哪些 head 是“高效的”;另一个用来预测 MLP 里哪些神经元是“高效的”(这里的“神经元”实际上指向参数矩阵的某一列或某一行,具体取决于参数矩阵的定义方式)。
在 Deja Vu 中,这两个模型的实现都采用了一个两层 MLP。
以预测 Attention head 编号为例,假设 head 数量是 256,那么只需将 MLP 的输出层大小设为 256,并为每一个输出加上 sigmoid 做二分类即可(选择或不选择)。
训练数据来自一个完整训练好的大模型。在该模型推理的过程中,记录下它的 Attention 输入和 Attention 输出,算出不同 head 的 L2 范数,然后基于一个 L2 范数的阈值 t,把 head 分成正例和负例。
预测 MLP 中需要选择的神经元编号,思路和上面基本一致。
下面以一个 Transformer 模块为例,展示一种朴素实现方法。
公式(1)到公式(4)的流程在逻辑上没问题,但它在原本的流程里额外加了两个步骤:预测 head 编号(公式1)和预测 MLP 神经元编号(公式3)。这可能导致网络的整体速度比原来用全参数时还慢!
下一部分会介绍论文里采用的高效实现方案。
3. 高效实现
基于MLP的稀疏性预测必须做到尽可能高效,否则整个大模型的推理时间可能因为引入额外的MLP而变得更慢。这部分整理一下论文里用到的一些高效实现策略。
3.1 并行化稀疏性预测
先说方案,再说原因。
因此原本只能串行执行的四个步骤(公式1到公式4),其中预测编号的两步可以和剩下的两步并行执行。
但为什么可以这样操作?
论文给出的理由是:大语言模型中 token 的 embedding 变化非常缓慢。
因为变化慢,所以提前一步去预测似乎也合情合理。论文用两张图说明了 token embedding 变化缓慢的现象。
图3.1中的左图,展示的是连续两个网络层之间 token embedding 的余弦相似度,高得离谱。右图则是间隔 n 层的 token embedding 之间的余弦相似度。两张图都很直观地表明,大语言模型里 token 的 Embedding 变化确实非常缓慢。因此,用前一层的输入去提前预测下一层的两个编号,是合理的。
3.2 Kernel Fusion
Kernel Fusion 算是绝大多数优化工作中的标准操作。在具体动手优化前,首先得想清楚 Kernel Fusion 是否可行。
在 PyTorch 里实现论文中的稀疏矩阵乘法,需要先用预测的编号索引,从参数矩阵里取出对应的参数。这会带来 3 次 I/O:1) 读参数矩阵 W;2) 取完对应索引后写参数矩阵 W1;3) 读 W1 然后做矩阵乘法。
很明显,对于当前场景,步骤2和步骤3是多余的操作。但在 PyTorch 提供的现有算子下,只能这样执行。
所以一个很直接的策略就是单独写一个 kernel,把步骤1、2、3合并到一起,这样 I/O 只有一次。
Kernel Fusion 这一步,速度提升了 4 倍(仅指这一步计算的速度)。
3.3 Memory coalescing
论文里给的例子感觉不太友好,和文中的符号有点对不上。
4. 其它
论文里详细的实验结果就不一一列举了。基于 OPT 模型,在 75% 的稀疏率下,Deja Vu 实现了 2 到 6 倍的加速,而且没有掉点,效果相当可观。
唯一的遗憾是,与很多大语言模型领域的论文相比,实验量还是偏小了一点。
这不禁让人联想到神经网络的过参数化问题。真是一个有趣的话题,很想多聊几句,但想多了容易乱,暂且点到为止。