前段时间笔者在arXiv上刷到了论文《Why Low-Precision Transformer Training Fails: An Analysis on Flash Attention》,里面描述的实验现象跟我们在训练Kimi K2时出现的一些现象很吻合,比如都是第二层Attention开始出现问题。论文将其归因为低精度Attention固有的有偏误差,这个分析角度是比较出乎笔者意料的,所以饶有兴致地阅读了一番。

然而,论文的表述似乎比较让人费解——当然也有笔者本就不大熟悉低精度运算的原因。总之,经过多次向作者请教后,笔者才勉强看懂论文,遂将自己的理解记录在此,供大家参考。

结论简述 #

要指出的是,论文标题虽然点名了“Flash Attention”,但按照论文的描述,即便block_size取到训练长度那么大,相同的问题依然会出现,所以Flash Attention的分块计算并不是引起问题的原因,因此我们可以按照朴素的低精度Attention实现来简化分析。

简单起见,我们只分析单头Attention,设$\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V}\in\mathbb{R}^{n\times d}$,记$\boldsymbol{S} = \boldsymbol{Q}\boldsymbol{K}^{\top}$,加粗的$\boldsymbol{1}$是指$n\times 1$的全1矩阵,$\boldsymbol{S}_{\max}$则指$\boldsymbol{S}$每行取最大值后得到的$n\times 1$矩阵,那么
\begin{equation}\boldsymbol{O} = \frac{\exp(\boldsymbol{S})\boldsymbol{V}}{\exp(\boldsymbol{S})\boldsymbol{1}} = \frac{\exp(\boldsymbol{S} - \boldsymbol{S}_{\max})\boldsymbol{V}}{\exp(\boldsymbol{S}- \boldsymbol{S}_{\max})\boldsymbol{1}}\end{equation}
我们记$\bar{\boldsymbol{P}} = \exp(\boldsymbol{S} - \boldsymbol{S}_{\max})$,那么Attention的关键计算是矩阵乘法$\bar{\boldsymbol{P}}\boldsymbol{V}$,它一般是在BF16精度下进行。论文给出的结论是:在低精度计算下,$\bar{\boldsymbol{P}}\boldsymbol{V}$这一步存在有偏的舍入误差。也就是说,在长期平均下,低精度计算的$\bar{\boldsymbol{P}}\boldsymbol{V}$跟准确值的差的期望并不是零。

这样一来,不同训练步骤之间的偏差可能就会持续累积,从而引起MaxLogit爆炸、Loss Spike等问题,直至训练崩溃。当然,严格来讲这只能算是MaxLogit爆炸等问题的一种可能的产生机制,不一定是全部,但即便如此,也值得我们学习和思考一番。

向偶舍入 #

为了理解论文结论,我们先来补习一些关于舍入误差的基本常识。之所以会写这一节,原因开头就说了——笔者本身并不熟悉低精度运算——所以这一节完全是写给自己补基础的,对此已有了解的读者完全可以略过。

我们知道,常用的舍入(Round)方式是“四舍五入”:在10进制中,一个正的1位小数要舍去最后一位,0~4就会变成0,产生的误差是$0,-0.1,-0.2,-0.3,-0.4$;5~9就会变成10,产生的误差是$0.5,0.4,0.3,0.2,0.1$。不知道大家发现没,这些误差的平均值并不是0,而是0.05,即“四舍五入”平均而言会放大原来的数字,产生正的偏差。

当然,相对偏差会随着舍去位数的增加而减少,比如一个2位小数要舍去2个小数位,平均误差则是0.005。但不论如何,四舍五入的这个正偏差总是存在的,只不过是大小不同。偏差的根源在中间点,比如0.51和0.49,它们分别往上/往下舍入,误差刚好抵消,但对于0.50不管规定它往上舍入还是往下舍入,都没有另一个数跟它抵消误差。

为了消除偏差,IEEE 754 提出了“向偶舍入(Round-to-Even)”原则,它规定对于中间情形,应该按照靠近偶数的方向舍入,比如2.5舍去最后一位要变成2,但3.5舍去最后一位则变成4,这样“5”就各有一半的几率产生$\pm 5$的误差,平均误差变为零,从而消除了偏差。

回到计算机领域。我们知道计算机使用二进制,它只有0和1,那么1就起到了10进制的“5”的角色。二进制中“四舍五入”的偏差更形象,因为末位只能是0或1:如果是0,自然不用改变,而如果是1,则触发“五入”而进1。所以,二进制数按“四舍五入”舍去末位,结果必然大于或等于原数,因此也需要“向偶舍入”来消除偏差。

BF16加法 #

接着我们重温一下BF16格式。BF16用16位二进制表示一个浮点数,其中1位符号、7位尾数、8位指数,8位指数的设计让它表示的范围跟FP32(1位符号、23位尾数、8位指数)一致,这也使它成为如今LLM训练的主要浮点格式。

BF16保留了较多的指数位,代价必然是尾数较少,从而能表示精度较低。为了缓解低精度带来累积误差,BF16运算采取的策略是“FP32累加”,也就是说BF16数的累加都是先转换成FP32,然后在FP32空间中相加得到FP32的结果,最后再转回BF16的。

现在我们考虑两个符号和指数相同的BF16数字相加。为什么要选指数相同来分析呢?因为我们要估计误差,指数相同意味着这两个数同数量级,相加后最有可能产生最大的误差。举个例子,如果两个数相加的数相差100倍,那么我哪怕直接返回最大者,误差也不过1%,所以最大误差往往在同数量级的数相加时发生。

两个符号和指数相同的BF16数字相加,必然会出现进位,比如“1.0000001 + 1.0000100 = 10.0000101 = 1.00000101 × 10”,这时候需要指数加1,并且舍去最后一位1,才能转换成BF16格式。如上一节所述,如果按照“四舍五入”舍去末位,那么将会产生正的偏差。不过我们已经知道,科学家早已发现了这个偏差,因此提出了“向偶舍入”来消除偏差。

两大一小 #

所以,到目前为止,一切结果都在可控和预期的范围内,还没有偏差产生。然而,不出意外的话,意外出现了。

现在让我们考虑三个同符号的数相加,这三个数的特点是:其中两个数指数相同且很大,第三个数很小。比如我们在上一节的例子“1.0000001 + 1.0000100”基础上再加上“0.0000000001”,那么得到“1.0000001 + 1.0000100 + 0.0000000001= 10.0000101001 = 1.00000101001 × 10”。

原本两个数相加,结果是“1.00000101 × 10”,舍去末位时会触发“向偶舍入”,得到“1.0000010 × 10”,可现在多了一个极小数,转换成BF16时要舍去的尾数变成了“1001”,比中间点更大,所以触发向上舍入原则,结果是“1.0000011 × 10”。那么在原本两个数相加的视角看来,第三个极小数的出现,破坏了“向偶舍入”规则,使得正偏差再次出现!

当然,这种情况出现条件看上去还是很苛刻的。首先三个数需要同号,其次需要满足“两大一小”,其中两个大数刚好能触发进位,然后小数小到只能影响FP32的尾数(即第9~23位尾数)。这样一来,小数很小,本身舍去都没多大误差,但它的存在,偏偏刚好能破坏了两个大数的“向偶舍入”规则,从而带来了单侧的偏差。

量身定制 #

这么苛刻的条件,实际中真的能出现吗?一般情况情况下还真不容易,但对于Attention来说,这仿佛就是“量身定制”的Bug!

我们取出$\bar{\boldsymbol{P}}\boldsymbol{V}$的某行某列(也就是某个元素),它可以写成
\begin{equation}\sum_{i=1}^n \bar{p}_i v_i \label{eq:sum-pi-vi}\end{equation}
其中$\bar{p}_i = \exp(s_i - \max(s_i))\leq 1$。我们知道,Softmax Attention的特点是能够“集中注意力”,也就是说注意力可能会集中在有限几个Token上,体现在$\bar{p}_i$上就是少数几个Token的$\bar{p}_i$接近于1,剩下的则会非常接近于0,但由于$\exp$的缘故,无法精确等于0(除非下溢出BF16的表示空间)。

然后,随着层数的堆叠和训练的进行,输入$\boldsymbol{V}$可能会出现“各向异性”,其中一种表现是某些维度的正负号分布不均匀,不失一般性,我们假设$v_i$大部分都是正数(负数同理),并且数量级大致相等。那么,求和$\eqref{eq:sum-pi-vi}$可以分为两部分:少数几个能接近于1的$\bar{p}_i$跟$v_i$相乘,成为求和的主项,剩下的余项是大部分接近于0的$\bar{p}_i$与$v_i$相乘。

论文考虑了一个特例:主项对应的几个$\bar{p}_i$并不是接近于1,而是都等于1,也就是$\boldsymbol{S}$的某些行同时存在多个$\max$。这个特例自然更难成立,但更容易理解,此时主项的$\bar{p}_i v_i$固有精度只有BF16。如此一来,“天时地利”俱备,完美触发了上一节说的Bug:

大部分项都是正数,主项精度都是BF16,求和满足进位条件;剩下余项极小,只能影响FP32最末端的尾数,刚好破坏了“向偶舍入”导致偏差;最后,由于“集中注意力”,主项的数目不会多,所以进位也不会太多(舍去位数越多,偏差越小),使得偏差处于显著区间!

这一套组合下来,可不就是为Attention定制的“专属Bug”?

干掉余项 #

了解问题的来龙去脉后,我们再来思考一下怎么解决问题。

表面上看,引发偏差的原因是极小的余项破坏了“向偶舍入”,但更深入思考一下,其实根本原因是“四舍五入”这个规则在中间处存在一个突变点,在突变点附近容易因为扰动而产生偏差,“向偶舍入”虽然能消除偏差,但消除不了突变点。理想的根治办法是Stochastic Rounding,也就是依概率向上/向下舍入,这样最大程度上避免了小扰动带来的偏差。

然而,据说Stochastic Rounding不容易有高效的硬件级实现,所以现在多数硬件的矩阵乘法算子都不带Stochastic Rounding。因此,原论文选择了另一条路径,直接面对问题,其思路笔者称为“干掉余项”。具体来说,在检测到某个触发条件时,我们将Attention的计算公式改为
\begin{equation}\boldsymbol{O} = \frac{\exp(\boldsymbol{S})\boldsymbol{V}}{\exp(\boldsymbol{S})\boldsymbol{1}} = \frac{\exp(\boldsymbol{S} - \beta\boldsymbol{S}_{\max})\boldsymbol{V}}{\exp(\boldsymbol{S}- \beta\boldsymbol{S}_{\max})\boldsymbol{1}}\end{equation}
其中$\beta > 1$。这样一来,每一项都需要多除以$\exp((\beta-1)\boldsymbol{S}_{\max})$,这是一个并不算小的数(论文设置$\beta \geq 2$),于是原本就极小的余项,就容易下溢至零而消失,那么“向偶舍入”便重新发挥作用,从而消除偏差。

那么,检测条件是什么呢?原论文考虑得比较简单,就是矩阵$\boldsymbol{S}$的行出现大于等于两次最大值时,修改就会触发,此时$\bar{p}_i$中至少有两个1。但笔者认为这里肯定有很大调整空间的,算是留下了一个改进方向吧。另外要注意的是,Flash Attention是分Block计算的,所以这个检测条件和修改也是按Block进行,细节可以参考原论文附录的代码。

延伸思考 #

总的来说,论文提供了理解MaxLogit爆炸等现象的一个比较独特的视角,它能解释一些事情,但无法覆盖全貌,也留下了很多值得思考的地方(吐槽点)。

首先,论文对Attention偏差的分析依赖于$\boldsymbol{V}$的各向异性,这也许可以解释为什么第2层Attention才出现MaxLogit爆炸等异常:因为第1层Attention的输入是Embedding,它相对来说还没那么容易出现各向异性;而第2层及以后的Attention的输入经过了前面的Attention,可能会固有地存在各向异性(参考)。

不过,这无法解释为什么MaxLogit爆炸只在个别层出现,比如论文的实验现象是只有第2层出问题,而K2的结果是2~4层出问题。同样地,这显然也无法解释为啥Muon比Adam更容易出现MaxLogit爆炸(出自Moonlight、K2)。所以,这应该是架构、优化器和低精度等多方面因素的综合结果,单看精度问题是不完整的。

此外,还有一个值得深思的问题是因果关系。论文的Attention偏差的另一个产生条件是注意力集中在少数几个Token上,此时对Attention计算进行干预,成功防止了它的后续异常。然而,笔者观察了一个正常训练的小模型,它的注意力没有想象中那么集中,比如平均Top-1的平均概率不到0.2、Top-400的累积概率才能达到0.9(训练长度4096)。

所以,Attention偏差究竟是训练崩溃的“因”还是“果”?换言之,当出现“注意力集中在少数几个Token上”时,有没有可能说明模型已经进入崩溃范围内了?这时候才进行干预,会不会“为时已晚”?比如虽然在指标上是防止了一些异常,但有没有可能模型已经没法Scale下去了?这些暂时都不得而知。

文章小结 #

本文分享了一篇关于低精度Attention计算偏差的分析论文,同时借着这个机会,给自己补习了一下低精度计算的基础内容。

转载到请包括本文地址:https://www.kexue.fm/archives/11371

更详细的转载事宜请参考:《科学空间FAQ》

如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。

如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!

如果您需要引用本文,请参考:

苏剑林. (Oct. 27, 2025). 《低精度Attention可能存在有偏的舍入误差 》[Blog post]. Retrieved from https://www.kexue.fm/archives/11371

@online{kexuefm-11371,
        title={低精度Attention可能存在有偏的舍入误差},
        author={苏剑林},
        year={2025},
        month={Oct},
        url={\url{https://www.kexue.fm/archives/11371}},
}