低精度Attention可能存在有偏的舍入误差
By 苏剑林 | 2025-10-27 | 21728位读者 |前段时间笔者在arXiv上刷到了论文《Why Low-Precision Transformer Training Fails: An Analysis on Flash Attention》,里面描述的实验现象跟我们在训练Kimi K2时出现的一些现象很吻合,比如都是第二层Attention开始出现问题。论文将其归因为低精度Attention固有的有偏误差,这个分析角度是比较出乎笔者意料的,所以饶有兴致地阅读了一番。
然而,论文的表述似乎比较让人费解——当然也有笔者本就不大熟悉低精度运算的原因。总之,经过多次向作者请教后,笔者才勉强看懂论文,遂将自己的理解记录在此,供大家参考。
结论简述 #
要指出的是,论文标题虽然点名了“Flash Attention”,但按照论文的描述,即便block_size取到训练长度那么大,相同的问题依然会出现,所以Flash Attention的分块计算并不是引起问题的原因,因此我们可以按照朴素的低精度Attention实现来简化分析。
简单起见,我们只分析单头Attention,设$\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V}\in\mathbb{R}^{n\times d}$,记$\boldsymbol{S} = \boldsymbol{Q}\boldsymbol{K}^{\top}$,加粗的$\boldsymbol{1}$是指$n\times 1$的全1矩阵,$\boldsymbol{S}_{\max}$则指$\boldsymbol{S}$每行取最大值后得到的$n\times 1$矩阵,那么
\begin{equation}\boldsymbol{O} = \frac{\exp(\boldsymbol{S})\boldsymbol{V}}{\exp(\boldsymbol{S})\boldsymbol{1}} = \frac{\exp(\boldsymbol{S} - \boldsymbol{S}_{\max})\boldsymbol{V}}{\exp(\boldsymbol{S}- \boldsymbol{S}_{\max})\boldsymbol{1}}\end{equation}
我们记$\bar{\boldsymbol{P}} = \exp(\boldsymbol{S} - \boldsymbol{S}_{\max})$,那么Attention的关键计算是矩阵乘法$\bar{\boldsymbol{P}}\boldsymbol{V}$,它一般是在BF16精度下进行。论文给出的结论是:在低精度计算下,$\bar{\boldsymbol{P}}\boldsymbol{V}$这一步存在有偏的舍入误差。也就是说,在长期平均下,低精度计算的$\bar{\boldsymbol{P}}\boldsymbol{V}$跟准确值的差的期望并不是零。
这样一来,不同训练步骤之间的偏差可能就会持续累积,从而引起MaxLogit爆炸、Loss Spike等问题,直至训练崩溃。当然,严格来讲这只能算是MaxLogit爆炸等问题的一种可能的产生机制,不一定是全部,但即便如此,也值得我们学习和思考一番。
向偶舍入 #
为了理解论文结论,我们先来补习一些关于舍入误差的基本常识。之所以会写这一节,原因开头就说了——笔者本身并不熟悉低精度运算——所以这一节完全是写给自己补基础的,对此已有了解的读者完全可以略过。
我们知道,常用的舍入(Round)方式是“四舍五入”:在10进制中,一个正的1位小数要舍去最后一位,0~4就会变成0,产生的误差是$0,-0.1,-0.2,-0.3,-0.4$;5~9就会变成10,产生的误差是$0.5,0.4,0.3,0.2,0.1$。不知道大家发现没,这些误差的平均值并不是0,而是0.05,即“四舍五入”平均而言会放大原来的数字,产生正的偏差。
当然,相对偏差会随着舍去位数的增加而减少,比如一个2位小数要舍去2个小数位,平均误差则是0.005。但不论如何,四舍五入的这个正偏差总是存在的,只不过是大小不同。偏差的根源在中间点,比如0.51和0.49,它们分别往上/往下舍入,误差刚好抵消,但对于0.50不管规定它往上舍入还是往下舍入,都没有另一个数跟它抵消误差。
为了消除偏差,IEEE 754 提出了“向偶舍入(Round-to-Even)”原则,它规定对于中间情形,应该按照靠近偶数的方向舍入,比如2.5舍去最后一位要变成2,但3.5舍去最后一位则变成4,这样“5”就各有一半的几率产生$\pm 5$的误差,平均误差变为零,从而消除了偏差。
回到计算机领域。我们知道计算机使用二进制,它只有0和1,那么1就起到了10进制的“5”的角色。二进制中“四舍五入”的偏差更形象,因为末位只能是0或1:如果是0,自然不用改变,而如果是1,则触发“五入”而进1。所以,二进制数按“四舍五入”舍去末位,结果必然大于或等于原数,因此也需要“向偶舍入”来消除偏差。
BF16加法 #
接着我们重温一下BF16格式。BF16用16位二进制表示一个浮点数,其中1位符号、7位尾数、8位指数,8位指数的设计让它表示的范围跟FP32(1位符号、23位尾数、8位指数)一致,这也使它成为如今LLM训练的主要浮点格式。
BF16保留了较多的指数位,代价必然是尾数较少,从而能表示精度较低。为了缓解低精度带来累积误差,BF16运算采取的策略是“FP32累加”,也就是说BF16数的累加都是先转换成FP32,然后在FP32空间中相加得到FP32的结果,最后再转回BF16的。
现在我们考虑两个符号和指数相同的BF16数字相加。为什么要选指数相同来分析呢?因为我们要估计误差,指数相同意味着这两个数同数量级,相加后最有可能产生最大的误差。举个例子,如果两个数相加的数相差100倍,那么我哪怕直接返回最大者,误差也不过1%,所以最大误差往往在同数量级的数相加时发生。
两个符号和指数相同的BF16数字相加,必然会出现进位,比如“1.0000001 + 1.0000100 = 10.0000101 = 1.00000101 × 10”,这时候需要指数加1,并且舍去最后一位1,才能转换成BF16格式。如上一节所述,如果按照“四舍五入”舍去末位,那么将会产生正的偏差。不过我们已经知道,科学家早已发现了这个偏差,因此提出了“向偶舍入”来消除偏差。
两大一小 #
所以,到目前为止,一切结果都在可控和预期的范围内,还没有偏差产生。然而,不出意外的话,意外出现了。
现在让我们考虑三个同符号的数相加,这三个数的特点是:其中两个数指数相同且很大,第三个数很小。比如我们在上一节的例子“1.0000001 + 1.0000100”基础上再加上“0.0000000001”,那么得到“1.0000001 + 1.0000100 + 0.0000000001= 10.0000101001 = 1.00000101001 × 10”。
原本两个数相加,结果是“1.00000101 × 10”,舍去末位时会触发“向偶舍入”,得到“1.0000010 × 10”,可现在多了一个极小数,转换成BF16时要舍去的尾数变成了“1001”,比中间点更大,所以触发向上舍入原则,结果是“1.0000011 × 10”。那么在原本两个数相加的视角看来,第三个极小数的出现,破坏了“向偶舍入”规则,使得正偏差再次出现!
当然,这种情况出现条件看上去还是很苛刻的。首先三个数需要同号,其次需要满足“两大一小”,其中两个大数刚好能触发进位,然后小数小到只能影响FP32的尾数(即第9~23位尾数)。这样一来,小数很小,本身舍去都没多大误差,但它的存在,偏偏刚好能破坏了两个大数的“向偶舍入”规则,从而带来了单侧的偏差。
量身定制 #
这么苛刻的条件,实际中真的能出现吗?一般情况情况下还真不容易,但对于Attention来说,这仿佛就是“量身定制”的Bug!
我们取出$\bar{\boldsymbol{P}}\boldsymbol{V}$的某行某列(也就是某个元素),它可以写成
\begin{equation}\sum_{i=1}^n \bar{p}_i v_i \label{eq:sum-pi-vi}\end{equation}
其中$\bar{p}_i = \exp(s_i - \max(s_i))\leq 1$。我们知道,Softmax Attention的特点是能够“集中注意力”,也就是说注意力可能会集中在有限几个Token上,体现在$\bar{p}_i$上就是少数几个Token的$\bar{p}_i$接近于1,剩下的则会非常接近于0,但由于$\exp$的缘故,无法精确等于0(除非下溢出BF16的表示空间)。
然后,随着层数的堆叠和训练的进行,输入$\boldsymbol{V}$可能会出现“各向异性”,其中一种表现是某些维度的正负号分布不均匀,不失一般性,我们假设$v_i$大部分都是正数(负数同理),并且数量级大致相等。那么,求和$\eqref{eq:sum-pi-vi}$可以分为两部分:少数几个能接近于1的$\bar{p}_i$跟$v_i$相乘,成为求和的主项,剩下的余项是大部分接近于0的$\bar{p}_i$与$v_i$相乘。
论文考虑了一个特例:主项对应的几个$\bar{p}_i$并不是接近于1,而是都等于1,也就是$\boldsymbol{S}$的某些行同时存在多个$\max$。这个特例自然更难成立,但更容易理解,此时主项的$\bar{p}_i v_i$固有精度只有BF16。如此一来,“天时地利”俱备,完美触发了上一节说的Bug:
大部分项都是正数,主项精度都是BF16,求和满足进位条件;剩下余项极小,只能影响FP32最末端的尾数,刚好破坏了“向偶舍入”导致偏差;最后,由于“集中注意力”,主项的数目不会多,所以进位也不会太多(舍去位数越多,偏差越小),使得偏差处于显著区间!
这一套组合下来,可不就是为Attention定制的“专属Bug”?
干掉余项 #
了解问题的来龙去脉后,我们再来思考一下怎么解决问题。
表面上看,引发偏差的原因是极小的余项破坏了“向偶舍入”,但更深入思考一下,其实根本原因是“四舍五入”这个规则在中间处存在一个突变点,在突变点附近容易因为扰动而产生偏差,“向偶舍入”虽然能消除偏差,但消除不了突变点。理想的根治办法是Stochastic Rounding,也就是依概率向上/向下舍入,这样最大程度上避免了小扰动带来的偏差。
然而,据说Stochastic Rounding不容易有高效的硬件级实现,所以现在多数硬件的矩阵乘法算子都不带Stochastic Rounding。因此,原论文选择了另一条路径,直接面对问题,其思路笔者称为“干掉余项”。具体来说,在检测到某个触发条件时,我们将Attention的计算公式改为
\begin{equation}\boldsymbol{O} = \frac{\exp(\boldsymbol{S})\boldsymbol{V}}{\exp(\boldsymbol{S})\boldsymbol{1}} = \frac{\exp(\boldsymbol{S} - \beta\boldsymbol{S}_{\max})\boldsymbol{V}}{\exp(\boldsymbol{S}- \beta\boldsymbol{S}_{\max})\boldsymbol{1}}\end{equation}
其中$\beta > 1$。这样一来,每一项都需要多除以$\exp((\beta-1)\boldsymbol{S}_{\max})$,这是一个并不算小的数(论文设置$\beta \geq 2$),于是原本就极小的余项,就容易下溢至零而消失,那么“向偶舍入”便重新发挥作用,从而消除偏差。
那么,检测条件是什么呢?原论文考虑得比较简单,就是矩阵$\boldsymbol{S}$的行出现大于等于两次最大值时,修改就会触发,此时$\bar{p}_i$中至少有两个1。但笔者认为这里肯定有很大调整空间的,算是留下了一个改进方向吧。另外要注意的是,Flash Attention是分Block计算的,所以这个检测条件和修改也是按Block进行,细节可以参考原论文附录的代码。
延伸思考 #
总的来说,论文提供了理解MaxLogit爆炸等现象的一个比较独特的视角,它能解释一些事情,但无法覆盖全貌,也留下了很多值得思考的地方(吐槽点)。
首先,论文对Attention偏差的分析依赖于$\boldsymbol{V}$的各向异性,这也许可以解释为什么第2层Attention才出现MaxLogit爆炸等异常:因为第1层Attention的输入是Embedding,它相对来说还没那么容易出现各向异性;而第2层及以后的Attention的输入经过了前面的Attention,可能会固有地存在各向异性(参考)。
不过,这无法解释为什么MaxLogit爆炸只在个别层出现,比如论文的实验现象是只有第2层出问题,而K2的结果是2~4层出问题。同样地,这显然也无法解释为啥Muon比Adam更容易出现MaxLogit爆炸(出自Moonlight、K2)。所以,这应该是架构、优化器和低精度等多方面因素的综合结果,单看精度问题是不完整的。
此外,还有一个值得深思的问题是因果关系。论文的Attention偏差的另一个产生条件是注意力集中在少数几个Token上,此时对Attention计算进行干预,成功防止了它的后续异常。然而,笔者观察了一个正常训练的小模型,它的注意力没有想象中那么集中,比如平均Top-1的平均概率不到0.2、Top-400的累积概率才能达到0.9(训练长度4096)。
所以,Attention偏差究竟是训练崩溃的“因”还是“果”?换言之,当出现“注意力集中在少数几个Token上”时,有没有可能说明模型已经进入崩溃范围内了?这时候才进行干预,会不会“为时已晚”?比如虽然在指标上是防止了一些异常,但有没有可能模型已经没法Scale下去了?这些暂时都不得而知。
文章小结 #
本文分享了一篇关于低精度Attention计算偏差的分析论文,同时借着这个机会,给自己补习了一下低精度计算的基础内容。
转载到请包括本文地址:https://www.kexue.fm/archives/11371
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (Oct. 27, 2025). 《低精度Attention可能存在有偏的舍入误差 》[Blog post]. Retrieved from https://www.kexue.fm/archives/11371
@online{kexuefm-11371,
title={低精度Attention可能存在有偏的舍入误差},
author={苏剑林},
year={2025},
month={Oct},
url={\url{https://www.kexue.fm/archives/11371}},
}










October 27th, 2025
苏神好,
两大一小相加产生误差至少需要满足以下条件:
1. 每一次相加都先提升到fp32相加后下降bf16
2. flash-attn kernel 使用以上混合精度机制
有一个疑问:如果不是,那单纯的bf16精度相加,过小的数值本身就是0了吧?
我没有仔细去看过flash-attn的cuda实现,但似乎有几个issue提到fa目前不支持fp32或者混合精度。(待考证
确实, 之前也看到了这篇文章, 虽然切入点很有意思, 但是在 flash attention 中 $pv$ 结果是 fp32 的, 并且也确实是 fp32 累加. 实际训练中应该并不是 $\sum pv$ 的精度问题.
对于 flash attention 而言, 不断 rescale O 再加上 pv 可能有更大的精度影响.
统一回复 @jerrick|comment-28687和@PengleZhang|comment-28688:
本文的解读,以及原论文的分析,都是考虑了bf16在fp32空间中累加后所出现的偏差,并不是说加一个数就cast一次到bf16的情形。它是在同号假设下,少量几个主项的bf16数,加上大量的极小的bf16数,在fp32空间求和,然后将结果cast成bf16所产生的偏差。
由于求和是在fp32下进行的,所以过小的数值一般也到不了零,但刚刚好破坏了向偶舍入,所以出现偏差,这就是产生偏差的大致逻辑。
感谢回复!
但是感觉这和我的解释没有冲突.
p 和 v 都是 bf16, 但是 pv 在 tensor core 中是 fp32 累加, 最后得到的结果也是 fp32.(bf16 x bf16 -> fp32).
所以 $\sum pv$ 的每一项都是 fp32, 并不是 bf16 recast 到 fp32 然后在 fp32 累加.
可能的误差是 p 从 fp32 转换到 bf16, 以及最终 o 从 fp32 转化到 bf16.
如果我没理解错,$\sum_i \bar{p}_i v_i$的每一项$\bar{p}_i v_i$都是bf16的(乘法是低精度的,加法才是高精度的),只是它们的累加在fp32完成?即等价运算是bf16精度的$\bar{p}_i v_i$,cast到fp32后求和,然后cast到bf16。
$\overline p$ 和 $v_i$ 是 bf16 的, 但是 $\overline p v_i$ 是 fp32. 低精度乘法得到 fp32, 然后进行 fp32 的累加.
如果我们相信 tensorcore 的精度(相信 nvidia), 则可以认为精度与 $\overline p, v_i$ fp32 -> bf16 -> fp32, 然后进行 fp32 的乘法是相同的.
不知道我们是否是一个意思.
我要表达的是:$\bar{p}_i$,$v_i$都是bf16的,$\bar{p}_i v_i$的dtype是fp32的,但它固有精度只有bf16,也就是说它就相当于一个bf16的数cast成fp32的。
当然我是完全的门外行,只是根据只言片语的信息给出的判断:
https://github.com/pytorch/pytorch/issues/146241
https://www.kimi.com/share/d3vj7u6mcu0njtiddlc0
嗯. 固有精度为 bf16 是有问题的.
在 `torch.matmul` 层面确实是 bf16 in bf16 out.
但是在 kernel 内部, 实际的 mma 指令是 bf16 in fp32 out.
我们在使用 cutlass 或者 cublas 时, 需要显式指定内部计算累加的 dtype.
使用 cutlass(cute) 构建的 flash attention 也不例外.
我们可以在 kernel 内调用 mma.sync 检查. c, d 的最后 16 bit 并不是 0.
也因如此, 在 hopper 中 fp8 gemm 有 bug, c, d 是 fp23 精度, 所以 deepseek 才需要自己写 deep gemm kernel 绕过 tensor core 累加.
October 27th, 2025
@PengleZhang|comment-28697
我新开一楼。设两个矩阵$A,B$,$C=AB$,那么
$$C_{i,j} = \sum_k A_{i,k} B_{k,j}$$
==========
1. “在 `torch.matmul` 层面确实是 bf16 in bf16 out.”
2. “但是在 kernel 内部, 实际的 mma 指令是 bf16 in fp32 out”
你这两行,应该都是指$C_{i,j}$的格式是bf16/fp32?还是指单个数乘法即$A_{i,k} B_{k,j}$的运算方式?
==========
“我们在使用 cutlass 或者 cublas 时, 需要显式指定内部计算累加的 dtype.”
这个是指求和$\sum_k$在fp32内完成?
==========
最后你的意思是指单个$A_{i,k} B_{k,j}$的格式是fp32,并且第8~23位尾数也非零是吧。这个我确实是不知道的。不过这样一来,bf16的matmul,不管乘加其实都在fp32内进行了,还有明显的加速空间吗?
不过倒是可以去思考一下,在这个说法下,本文的分析应该怎么修改~
是的, 我们可以理解为, 低精度 tensor core 理论上的误差, 只有 $A,B$ 转换到 bf16 这一步. 其他部分与 fp32 一致.
我们考虑计算一次乘加 $d=ab+c$. 尽管 $ab,c,d$ 都是 fp32, 但是 $ab$ 这一步骤本身是 bf16*bf16. 相较于 fp32*fp32 仍然能够节省 cycle.
tensor core 的主要加速在于大矩阵可以通过脉冲阵列等方法实现高效的硬件并行.
而低精度可以降低乘加中乘法的复杂度, 并且降低数据带宽需求.
文章切入的角度很好. 尽管我们现在都在用 bf16 甚至 fp8 训练, 但是对于单个操作的精度分析远远不足. 之前 meta 有讨论 flash attention的fusion与普通 attention 实现精度误差的文章, 最近也有讨论 RoPE qk 截断到 bf16 精度影响的文章, 但这方面的探索整体还是不足的.
受教了,感谢指点,我沿着修正后的理解继续思考一下。
对于 bf16 tensor core $c=c+ab$ ,其中累加器 $c$ 是 fp32,$a,b$ 都是 bf16 (7 位精度)。
论文所描述的情况是:由于存在两个 $i$ 使得 $p_i = 1$(最大值有两个),且 $v_i$ 只有 7 位精度,因此 $p_i v_i$ 同样只有 7 位精度,因此 $\sum p_iv_i$ 这个 fp32 结果 cast 回 bf16 的误差是有偏的。
抱歉,目前你看到的最新版blog,是我根据@PengleZhang|comment-28699的指点和论文内容重新调整了一下描述,而@PengleZhang|comment-28699的评论则是根据旧版内容,所以你们俩都没问题。
October 28th, 2025
苏神您好!浅问下有啥版本可以获取您的博客的latex版本吗?这样给LLM帮我理解的时候,比单纯复制会方便很多哎~
暂时没有。你可以直接把链接给模型呀,有爬网页能力的AI,可以直接获取到页面源码,就是latex格式。
比如“我们要讨论的文章是《随机矩阵的谱范数的快速估计》,链接是 https://kexue.fm/archives/11335 。请以此为基础,继续回答我后面的问题。”
October 29th, 2025
> 舍去位数越多,偏差越小
请问这句话如何理解呢?舍去位数越多,应该丢失了更多的数值,那偏差应该越大吧?
偏差不是误差,偏差是平均误差与0的差距。“向偶舍入”一节已经举例子了,10进制舍去1位小数的平均误差是0.05,舍去2位小数的平均误差是0.005,所以舍去越多,偏差越小。
October 29th, 2025
苏神您好,这块:现在让我们考虑三个同符号的数相加,这三个数的特点是:其中两个数指数相同且很大,第三个数很小。比如我们在上一节的例子“1.0000001 + 1.0000100”基础上再加上“0.0000000001”,那么得到“1.0000001 + 1.0000100 + 0.0000000001= 10.0000101001 = 1.00000101001 × 10”,这里最后应该是乘2?
2进制,所以是10~
明白!谢谢
November 12th, 2025
"理想的根治办法是Stochastic Rounding,也就是依概率向上/向下舍入,这样最大程度上避免了小扰动带来的偏差。"
苏神想请教下这里,这篇paper好像没有做Stochastic Rounding的相关实验,你的意思是不是如果有高效的硬件指令实现Stochastic Rounding,PV=O用fp32做累加,最终通过Stochastic Rounding将O convet成bf16,那paper中遇到的精度问题就可以解决了?
理论上是的,Stochastic Rounding后它就不会因为微小的、实际上几乎可以忽略的小项而产生有偏的误差了。
好的,感谢!