Efficient Transformer,泛指一切致力于降低Transformer的二次复杂度的工作,开始特指针对Attention的改进,后来更一般的思路,如傅立叶变换、线性RNN等,也被归入这个范畴。不得不说,为了降低Transformer的二次复杂度,各路大牛可谓是“八仙过海,各显神通”,各种神奇的思路“百花齐放”,笔者也从中学习到了不少理论知识。然而,尽管Efficient Transformer在理论上是精彩的,但实际上该领域一直都是不愠不火的状态,并没有实际表现十分出色的模型,在LLM火爆的今天,甚至已经逐渐淡出了大家的视野,也淡出了笔者的兴趣范围。

不过,最近有一篇论文《Transformer-VQ: Linear-Time Transformers via Vector Quantization》,却让笔者为之拍案叫绝。作者非常高明地洞察到,只需要对标准Attention的Key做一下VQ(Vector Quantize),复杂度就会自动降低为线性!这种线性化思路保留了标准Attention的形式,是标准Attention到线性Attention的一个完美过渡,同时最大程度上保留了标准Attention的能力。

高效难题 #

说起来,本站也算是比较早关注Efficient Transformer相关工作了,最早可以追溯到2019年解读Sparse Transformer的一篇博客《为节约而生:从标准Attention到稀疏Attention》。此后,陆续写的关于Efficient Transformer的其他博文还有

《线性Attention的探索:Attention必须有个Softmax吗?》

《Performer:用随机投影将Attention的复杂度线性化》

《Nyströmformer:基于矩阵分解的线性化Attention方案》

《Transformer升级之路:3、从Performer到线性Attention》

《线性Transformer应该不是你要等的那个模型》

《FLASH:可能是近来最有意思的高效Transformer设计》

《Google新作试图“复活”RNN:RNN能否再次辉煌?》

然而,正如本文开头所说,尽管Efficient Transformer已有不少工作,也曾被大家寄予厚望,但实际上该领域一直都没什么能“出圈”的作品,这其中的原因可能是:

1、不少Efficient Transformer的提速以牺牲效果为代价;

2、很多Efficient Transformer的复杂度降低仅仅是理论上的,实际使用提升不明显;

3、有些Efficient Transformer难以用来训练Causal LM,所以在LLM流行的今天就没有了用武之地;

4、Flash Attention的出现表明即便是标准的Transformer仍有很大的提速空间。

VQ一下 #

那么,Transformer-VQ为何又具备的“出圈”潜力?

简单来说,Transformer-VQ就是对Attention的Key向量序列进行了“聚类”,并用所属类的类别中心近似原向量,然后Attention的复杂度就变成线性了。也就是说,Transformer-VQ仅仅改变了Key的形似,其余部分(理论上)完全不变,所以这是一种对Attention改动非常小的线性化方案,也能非常清楚体现出线性化后损失的精度在哪里(即用类别中心近似原向量的差距)。

铺垫得有点多了,现在我们正式介绍Transformer-VQ。首先,我们假设$Q,K\in\mathbb{R}^{n\times d_k},V\in\mathbb{R}^{n\times d_v}$,标准Attention就是
\begin{equation}softmax\left(QK^{\top}\right)V\end{equation}
简单起见,这里省略了scale factor。Transformer-VQ改为
\begin{equation}softmax\left(Q\hat{K}^{\top}\right)V,\quad \hat{K} = \color{skyblue}{\mathcal{VQ}}(K, C)\label{eq:vq-att}\end{equation}
其中$C\in\mathbb{R}^{c\times d_k}$是训练参数,也是VQ的编码表(Codebook)。对了,这里的“VQ”就是指VQ-VAE中的VQ,不了解的读者可以移步参考《VQ-VAE的简明介绍:量子化自编码器》《简单得令人尴尬的FSQ:“四舍五入”超越了VQ-VAE》,这里不重复介绍了。总之,经过$\color{skyblue}{\mathcal{VQ}}$之后,最直接的表现就是$K$的每个向量都变成了$C$中与之最相近的那个,这意味着$\hat{K}$的每个向量都是$C$的向量之一,用数学的语言就是说$K\in\mathbb{R}^{n\times d_k}$变成了$\hat{K}\in C^n$。

Encoder #

当然,直接按照式$\eqref{eq:vq-att}$去实现Transformer-VQ的话,复杂度还是二次的,但由于$\hat{K}$的每个向量都是$C$的向量之一,所以我们可以先算$\exp\left(QC^{\top}\right)$,然后从中“挑出”$\exp\left(Q\hat{K}{}^{\top}\right)$对应的结果,而由于$C$的大小是固定的,所以关键运算$QC^{\top}$的复杂度是线性的,这就是Transformer-VQ能线性化的原理(我们不妨称为“挑出”技巧)。

作为铺垫,我们先考虑双向注意力的Encoder情形。由于
\begin{equation}softmax\left(QK^{\top}\right)V = \frac{\exp\left(QK^{\top}\right)V}{\exp\left(QK^{\top}\right)1_{n\times 1}}\label{eq:softmax-qkv}\end{equation}
这里$1_{n\times 1}$指的是$n\times 1$大小的全1矩阵,分母可以视为分子的一个特殊形式,所以我们只需要考虑分子$\exp\left(QK^{\top}\right)V$。由于$\hat{K}$的每个向量都是$C$中之一,所以我们可以构建一个one hot矩阵$\Delta\in \{0,1\}^{n\times c}$,其中$\Delta_i\in\{0,1\}^c$是一个one hot向量,如果1所在的维度为$j$,那么$\hat{K}_i = C_j$,于是$\hat{K}=\Delta C$。

于是对于Transformer-VQ来说有
\begin{equation}\exp\left(Q\hat{K}{}^{\top}\right)V = \exp\left(QC^{\top}\Delta^{\top}\right)V = \exp\left(QC^{\top}\right)\Delta^{\top}V = \exp\left(QC^{\top}\right)(\Delta^{\top}V)\end{equation}
很明显,这里最关键的地方就是第二个等号!对于one hot矩阵$\Delta$,右乘以它的转置可以从$\exp$中分离出来,这就是原理中的“挑出”技巧的数学表述。分离出来之后,由于矩阵乘法结合律,$\Delta^{\top}$可以先跟$V$相乘,得到一个$c\times d_v$的矩阵,而$\exp\left(QC^{\top}\right)$是一个$n\times c$的矩阵,乘以$\Delta^{\top}V$就得到一个$n\times d_v$的矩阵,总的理论复杂度是$\mathscr{O}(ncd_k + ncd_v + ncd_v) = \mathscr{O}(n)$。

最后,根据式$\eqref{eq:softmax-qkv}$,将$\exp\left(Q\hat{K}{}^{\top}\right)V$的结果代入去,就可以计算完整的Attention结果(可能还要加一些避免溢出的细节),整个过程可以在线性复杂度内完成。

Decoder #

现在我们来考虑单向注意力的Decoder,这是训练生成模型的关键,也是当前LLM的基础。有了Encoder的铺垫后,Decoder理解起来也就没那么困难了。假设$Q_i, \hat{K}_j \in \mathbb{R}^{1\times d_k}, V_j\in\mathbb{R}^{1\times d_v}$是向量序列$Q,\hat{K},V$的行向量之一,那么对于Decoder的分子有
\begin{equation}\begin{aligned}
O_i =&\, \sum_{j\leq i}\exp\left(Q_i\hat{K}{}_j^{\top}\right)V_j = \sum_{j\leq i}\exp\left(Q_i C^{\top}\Delta_j^{\top}\right)V_j \\
=&\, \sum_{j\leq i}\exp\left(Q_i C^{\top}\right)\Delta_j^{\top}V_j = \exp\left(Q_i C^{\top}\right)\sum_{j\leq i}\Delta_j^{\top}V_j
\end{aligned}\end{equation}
如果$c\times d_v$不大,那么最后的式子可以直接用$\text{cumsum}$算子完成,不过一般情况下,尤其是Multi-Heaad时,为了节省显存,通常是跟《线性Attention的探索:Attention必须有个Softmax吗?》中的“自回归生成”一节一样,转为RNN来递归计算,即设$U_i = \sum_{j\leq i}\Delta_j^{\top}V_j\in\mathbb{R}^{c\times d_v}$,那么
\begin{equation}O_i = \exp\left(Q_i C^{\top}\right)U_i,\quad U_i = U_{i-1} + \Delta_i^{\top}V_i
\end{equation}
在推理阶段这样step by step递归计算自然是没问题,但训练阶段step by step的话可能会比较慢,我们可以改为block by block来加速:不失一般性,设$n=lm$,$l$代表block_size,$m$代表block数目,block切片$[il:(i+1)l]$简写为$[i]$,那么
\begin{equation}\begin{aligned}
O_{[i]} =&\, \exp\left(Q_{[i]}\hat{K}{}_{[i]}^{\top} + M\right)V_{[i]} + \sum_{j\lt i}\exp\left(Q_{[i]}\hat{K}{}_{[j]}^{\top}\right)V_{[j]} \\
=&\, \exp\left(Q_{[i]}\hat{K}{}_{[i]}^{\top} + M\right)V_{[i]} + \sum_{j\lt i}\exp\left(Q_{[i]}C^{\top}\Delta_{[j]}^{\top}\right)V_{[j]} \\
=&\, \exp\left(Q_{[i]}\hat{K}{}_{[i]}^{\top} + M\right)V_{[i]} + \exp\left(Q_{[i]}C^{\top}\right)\sum_{j\lt i}\Delta_{[j]}^{\top}V_{[j]} \\
\end{aligned}\end{equation}
其中$M\in\{-\infty,0\}^{l\times l}$是下三角的Attention Mask,即当$i \geq j$时$M_{i,j}=0$,否则$M_{i,j}=-\infty$。于是记$U_i = \sum_{j\lt i}\Delta_{[j]}^{\top}V_{[j]}$后,我们有
\begin{equation}O_{[i]} = \exp\left(Q_{[i]}\hat{K}{}_{[i]}^{\top} + M\right)V_{[i]} + \exp\left(Q_{[i]}C^{\top}\right)U_{i-1},\quad U_i = U_{i-1} + \Delta_{[i]}^{\top}V_{[i]}
\end{equation}
这样我们就将递归步数减少为$m$了,可以在保证线性效率的同时,更充分发挥硬件的并行能力。用同样的方式也可以计算分母,最后相除得到完整的Attention结果

局域增强 #

就这样完了?并不是,如果仅仅是这样的话,Transformer-VQ可能跟以往基于矩阵分解的Kernelized Attention如Performer并没有太多区别。当序列长度$n$远大于编码表大小$c$时,由抽屉原理我们知道部分编码向量必然会反复出现,甚至可以合理猜测所有编码向量应该会均匀分布在整个序列中。这样一来,邻近token的Attention就会跟远处某些token的Attention一样,也就是说模型无法区分远近,这本质上就是所有Kernelized Attention都存在的低秩问题。

已有的经验告诉我们,对于语言模型来说,相对于远处的token的来说邻近的token往往更为重要,所以一个好的语言模型架构应该具有区分远近的能力。为此,Transformer-VQ选择在$Q\hat{K}$之后,加上一个Sliding Window形状的Attention Bias(记为$B$),来对邻近token进行加权,如下图:

Window Attention Bias示意图

Window Attention Bias示意图

从最后一个图可以看出,如果将Window大小直接设为block大小$l$,即$i < j$或者$i - j \leq l$时$B_{i,j}=0$,那么在分block计算时,矩阵$B$顶多影响最邻近的两个block,再远的block依旧可以用“挑出”技巧来线性化。为了便于下面的推导,我们记$B_{[i,j]} = B_{[il:(i+1)l,jl:(j+1)l]}$,那么
\begin{equation}\begin{aligned}
O_{[i]} =&\, \exp\left(Q_{[i]}\hat{K}{}_{[i]}^{\top} + B_{[i,i]}\right)V_{[i]} + \exp\left(Q_{[i]}\hat{K}{}_{[i-1]}^{\top} + B_{[i,i-1]}\right)V_{[i-1]} + \sum_{j\lt i-1}\exp\left(Q_{[i]}\hat{K}{}_{[j]}^{\top}\right)V_{[j]} \\
=&\, \exp\left(Q_{[i]}\hat{K}{}_{[i]}^{\top} + B_{[i,i]}\right)V_{[i]} + \exp\left(Q_{[i]}\hat{K}{}_{[i-1]}^{\top} + B_{[i,i-1]}\right)V_{[i-1]} + \sum_{j\lt i-1}\exp\left(Q_{[i]}C^{\top}\Delta_{[j]}^{\top}\right)V_{[j]} \\
=&\, \exp\left(Q_{[i]}\hat{K}{}_{[i]}^{\top} + B_{[i,i]}\right)V_{[i]} + \exp\left(Q_{[i]}\hat{K}{}_{[i-1]}^{\top} + B_{[i,i-1]}\right)V_{[i-1]} + \exp\left(Q_{[i]}C^{\top}\right)\sum_{j\lt i-1}\Delta_{[j]}^{\top}V_{[j]} \\
\end{aligned}\end{equation}

所以很明显,有(约定$V_{[-1]},U_{[-1]},U_{[-2]}$都是全零矩阵)
\begin{equation}\begin{aligned}
O_{[i]} =&\, \exp\left(Q_{[i]}\hat{K}{}_{[i]}^{\top} + B_{[i,i]}\right)V_{[i]} + \exp\left(Q_{[i]}\hat{K}{}_{[i-1]}^{\top} + B_{[i,i-1]}\right)V_{[i-1]} + \exp\left(Q_{[i]}C^{\top}\right)U_{i-2}\\[5pt]
U_i =&\, U_{i-1} + \Delta_{[i]}^{\top}V_{[i]}
\end{aligned}\label{eq:tvq}\end{equation}
笔者认为,$B$的引入是Transformer-VQ是跟其他Kernelized Attention拉开差距的关键,为了减少参数量且支持变长生成,我们约束B的非零部分为“Toeplitz矩阵”,即$B_{i,j}$是$i-j$的函数,此时$B$就相当于加性相对位置编码。除了这种做法外,也可以考虑换为笔者之前提出的ReRoPE,它是旋转位置编码的窗口版,跟$B$具有同样的相对位置编码形状。

梯度回传 #

等等,我们好像忘记了点什么。了解VQ-VAE的读者都知道,“$\hat{K}$的每个向量都是$C$的向量之一”只是前向传播的表现,反向传播用的可是原始的$K$,这意味着即便不同位置的$\hat{K}_j$等于同一个$C_k$,但它们的梯度却不相等,这叫做STE(Straight-Through Estimator)。由于STE的存在,“挑出”技巧理论上仅可用于推理阶段,训练阶段是无法线性化的。

没有其他办法了吗?确实如此,如果我们坚持要获得精确的梯度结果,那么并没有线性化效率的方案。然而,考虑到VQ的梯度本身就是近似的,所以Attention获取精确的梯度似乎也没多大必要。于是作者想了个折衷的方案:依然是按照式$\eqref{eq:tvq}$进行递归计算,仅在前两项使用STE(Key序列可以获得梯度),而$U_{i-1}$的梯度直接停掉($\text{stop_gradient}$算子)。这样我们就保持了模型的线性性,同时也已经保留了最重要的梯度(邻近的两个block),算是一个比较合理的近似方案。从这一点来看,Transformer-VQ跟Transformer-XL很像,Transformer-XL在递归的同时也停掉了历史窗口的梯度,即历史窗口可以参与递归计算,不传递梯度。

解决了梯度回传问题之后,在自回归交叉熵损失的基础上,再上VQ带来的用来更新编码表的辅助loss,就得到完整的训练目标了。当然,对于编码表的更新,Transformer-VQ采用了直接滑动平均的方案,所以只补充了Key的辅助loss,这些细节读者在熟悉VQ-VAE之后,稍微看一下原论文就理解了。

实验结果 #

这一节我们来看一下原论文的实验结果。作者已经将代码开源如下:

值得指出的是,作者做VQ的基础架构并不是常规的MHA(Multi-Head Attention),而是笔者一直很推崇的GAU(Gated Attention Unit)+Softmax,Transformer-VQ更准确的命名应该是“GAU-VQ”,不了解GAU的读者可以参考《FLASH:可能是近来最有意思的高效Transformer设计》《听说Attention与Softmax更配哦~》。简单来说,GAU本身比MHA有着更高的效率,配合上VQ技巧后,就更加“如虎添翼”了。

实验方面,作者做了语言模型(ENWIK8、PG-19)和图像生成(IMAGENET64),所有的实验中的编码表大小都是$c=512$。模型最大参数量为1.3B,虽然比不上主流的大模型参数量,但其实对于科研来说不算小了。实验结果总体来说算得上优异:

PG-19的实验结果

PG-19的实验结果

IMAGENET64的实验结果

IMAGENET64的实验结果

最后,让人惊奇的是,Transformer-VQ的作者只有一个,并且身份是“Independent Researcher”。

发散思考 #

笔者发现,从Transformer-VQ出发,可以联系到非常多的研究主题,这也是为什么笔者如此欣赏它的原因之一。

首先,再次为作者惊人的洞察力点赞,“只需VQ一下Key,Transformer的复杂度就会变成线性”这个发现实在太美妙了,它实现了标准Attention到线性Attention的自然过渡,并且可以通过加Attention Bias的方式让它比很多的Kernelized Attention都有效。然后,通过VQ进行“聚类”的方式,也比LinformerNyströmformer等更为高明,因为它防止了未来信息的泄漏,可以自然地用来做Causal的语言模型。

我们知道,VQ本质上也是将序列转为离散id的运算,这跟Tokenizer的作用是非常相似的。从这个角度来看,Transformer-VQ跟MegaByte等模型一样,都是将Tokenizer内置在模型之中,并且相比MegaByte,VQ这一操作跟我们传统意义上的Tokenizer更为相似、直观。所以,Transformer-VQ实际上非常适合用来训练直接以Bytes输入的“No Tokenizer”模型,事实上,上述ENWIK8实验就是Bytes输入,Transformer-VQ效果明显优于MegaByte。

相比近来出的RetNet,Transformer-VQ没有显式的远程衰减,所以Long Context能力有可能会更好,同时由于Key经过了VQ,都是有限集合之一,所以不会出现没有学过的Key,因此长度外推能力大概率也会更好。虽然Transformer-VQ的基础架构GAU只是Single-Head的,但它在递归过程中模型记忆状态大小是$\Delta_i^{\top}V_i\in\mathbb{R}^{c\times d_v}$,在默认的设置中,这比Multi-Head的RetNet还大(RetNet的记忆状态大小是$nd_k^2$,默认设置下$d_v = 2nd_k$),因此,记忆容量理论上是足够的。

由于上一篇文章刚好写了《简单得令人尴尬的FSQ:“四舍五入”超越了VQ-VAE》,可能会有读者想知道可否用更简单的FSQ取代VQ?笔者认为比较难,原因其实在上一篇文章给出了:第一,$c=512$还属于VQ优于FSQ的编码数量范围,所以换FSQ大概率会掉效果;第二,由于每层Attention的Key都要被VQ,所以平均来说VQ的Encoder和Decoder都不强,这种情况VQ近似精度更高,FSQ更适合Decoder和Decoder都足够强的场景;第三,Transformer-VQ需要用的是Key被VQ之后的中心向量而不是id,而FSQ则直接得到id,反而不容易恢复为近似的中心向量。

除此之外,用VQ而不是FSQ,使得Transformer-VQ有希望从现有的预训练模型如LLAMA2中微调过来,而不单单是从零训练。因为VQ具有鲜明的几何意义,跟K-Means有诸多相通之处,我们可以从现有预训练模型出发,选取一些样本计算出Key,对Key进行K-Means得到中心向量作为编码表的初始化,然后在原模型基础上加上VQ进行微调。不过Transformer-VQ不大好适配RoPE,所以要如前面所说,RoPE的模型要换成ReRoPE再VQ比较好,此时就可以不用加Bias了。

总之,在笔者眼中,Transformer-VQ在众多Efficient Transformer工作中,是非常独特、出色而又潜力深厚的之一。

文章小结 #

本文介绍了一个名为Transformer-VQ的Efficient Transformer方案,它基于“只需VQ一下Key,Transformer的复杂度就会变成线性”的观察结果进行展开,个人认为是一种非常独特且亮眼的线性化思路,实验结果也很优异。它既可以理解为一种更高明的线性Attention/RNN模型,也可以理解为一个带有“可训练的Tokenizer”的Attention模型。

转载到请包括本文地址:https://www.kexue.fm/archives/9844

更详细的转载事宜请参考:《科学空间FAQ》

如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。

如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!

如果您需要引用本文,请参考:

苏剑林. (Nov. 09, 2023). 《VQ一下Key,Transformer的复杂度就变成线性了 》[Blog post]. Retrieved from https://www.kexue.fm/archives/9844

@online{kexuefm-9844,
        title={VQ一下Key,Transformer的复杂度就变成线性了},
        author={苏剑林},
        year={2023},
        month={Nov},
        url={\url{https://www.kexue.fm/archives/9844}},
}