脑洞大开:非线性RNN居然也可以并行计算?
By 苏剑林 | 2023-09-26 | 43487位读者 |近年来,线性RNN由于其可并行训练以及常数推理成本等特性,吸引了一定研究人员的关注(例如笔者之前写的《Google新作试图“复活”RNN:RNN能否再次辉煌?》),这让RNN在Transformer遍地开花的潮流中仍有“一席之地”。然而,目前看来这“一席之地”只属于线性RNN,因为非线性RNN无法高效地并行训练,所以在架构之争中是“心有余而力不足”。
不过,一篇名为《Parallelizing Non-Linear Sequential Models over the Sequence Length》的论文有不同的看法,它提出了一种迭代算法,宣传可以实现非线性RNN的并行训练!真有如此神奇?接下来我们一探究竟。
求不动点 #
原论文对其方法做了非常一般的介绍,而且其侧重点是PDE和ODE,这里我们直接从RNN入手。考虑常见的简单非线性RNN:
\begin{equation}x_t = \tanh(Ax_{t-1} + u_t)\label{eq:rnn}\end{equation}
由于$\tanh$的存在,它只能串行计算。现在我们在两边都减去$Ax_{t-1}$:
\begin{equation}x_t - Ax_{t-1} = \tanh(Ax_{t-1} + u_t) - Ax_{t-1}\end{equation}
当然,这改变不了它是非线性RNN的实质。然而我们可以发现,假如右端的$x_{t-1}$换成像$u_t$那样的给定向量,那么这就是一个线性RNN了,根据《Google新作试图“复活”RNN:RNN能否再次辉煌?》的结果,它是可以并行计算的。此时,敏捷的读者可能已经猜到后面的步骤了——迭代求解!
首先,将上述RNN更改成
\begin{equation}x_t^{(n)} - Ax_{t-1}^{(n)} = \tanh(Ax_{t-1}^{(n-1)} + u_t) - Ax_{t-1}^{(n-1)}\label{eq:rnn-iter}\end{equation}
从给定$x_t^{(0)}$出发,反复迭代上式,理想情况下,它会收敛于一个不动点$x_t^*$,这就是原来非线性RNN的计算结果。当然,理论上通过式$\eqref{eq:rnn-iter}$迭代的总计算量是比直接通过式$\eqref{eq:rnn}$递归计算要大的,但由于每一步迭代都是可并行的线性RNN,并且如果收敛速度比较快时迭代步数不需要太多,那么总的耗时通常都会快于直接非线性RNN递归(尤其是序列长度很大时)。
简化形式 #
事实上,非线性RNN之所以慢,无法并行计算还是次要的,最关键是它包含了大量的非element-wise运算,比如式$\eqref{eq:rnn}$的$\tanh$里边的矩阵运算$Ax_{t-1}$;而线性RNN之所以快,除了它允许并行训练之外,更关键的是它能通过对角化来将矩阵乘法变换为element-wise的乘法——对于element-wise乘法来说,即便是串行计算也不会太慢。
当我们通过式$\eqref{eq:rnn-iter}$将非线性RNN转为线性RNN的迭代之后,同样享受线性RNN可对角化的“待遇”,从而提高计算速度。具体来说,在复数域中将$A$对角化为$P\Lambda P^{-1}$,那么式$\eqref{eq:rnn-iter}$变为
\begin{equation}x_t^{(n)} - P\Lambda P^{-1} x_{t-1}^{(n)} = \tanh(P\Lambda P^{-1} x_{t-1}^{(n-1)} + u_t) - P\Lambda P^{-1} x_{t-1}^{(n-1)}\end{equation}
两端都左乘$P^{-1}$:
\begin{equation}P^{-1} x_t^{(n)} - \Lambda P^{-1} x_{t-1}^{(n)} = P^{-1}\tanh(P\Lambda P^{-1} x_{t-1}^{(n-1)} + u_t) - \Lambda P^{-1} x_{t-1}^{(n-1)}\end{equation}
令$y_t = P^{-1} x_t$,那么上式可以简化为
\begin{equation}y_t^{(n)} - \Lambda y_{t-1}^{(n)} = P^{-1}\tanh(P\Lambda y_{t-1}^{(n-1)} + u_t) - \Lambda y_{t-1}^{(n-1)}\end{equation}
由于RNN之后一般都还要接个投影层,所以$x_t = P y_t$的$P$原则上可以合并到外接的投影层里边,也就是说,上式理论上具备跟原来的$\eqref{eq:rnn}$具备同等的表达能力,但由于$\Lambda$是对角阵,递归的计算量会明显降低。上式还出现了逆矩阵$P^{-1}$,不单计算量大,而且不利于优化,所以我们可以干脆将$P^{-1}$和$P\Lambda$换成两个不相关的参数矩阵:
\begin{equation}y_t^{(n)} - \Lambda y_{t-1}^{(n)} = P\tanh(Q y_{t-1}^{(n-1)} + u_t) - \Lambda y_{t-1}^{(n-1)}\end{equation}
只要初始化是$PQ=\Lambda$就行。
摄动思想 #
假定$x_t^{(0)}=0$,那么式$\eqref{eq:rnn-iter}$其实就是将原本的非线性RNN就分解为一系列线性RNN:
\begin{equation}\begin{array}{c}
x_t^{(1)} - Ax_{t-1}^{(1)} = \tanh(u_t)\\
x_t^{(2)} - Ax_{t-1}^{(2)} = \tanh(Ax_{t-1}^{(1)} + u_t) - Ax_{t-1}^{(1)} \\
\vdots \\
x_t^{(n)} - Ax_{t-1}^{(n)} = \tanh(Ax_{t-1}^{(n-1)} + u_t) - Ax_{t-1}^{(n-1)} \\
\vdots \\
\end{array}\label{eq:rnns}\end{equation}
而假设$x_{t-1},u_t$都是小量,那么对式$\eqref{eq:rnn}$右端利用$\tanh x \approx x$得到:
\begin{equation}x_t = \tanh(Ax_{t-1} + u_t) \approx Ax_{t-1} + u_t \approx Ax_{t-1} + \tanh(u_t)\label{eq:rnn-approx}\end{equation}
这正好是$\eqref{eq:rnns}$中的第一个方程,因此如果假设成立,那么$x_t^{(1)}$或许已经足够接近理想的$x_t^*$,后面的每一步迭代都在快速逼近它。从这里我们可以看出,“两边同时减去$Ax_{t-1}$”是关键之处,这使得$\eqref{eq:rnn-iter}$的第一步迭代就接近于原本非线性RNN的一阶线性近似,这可以提高收敛速度,也是数学物理中的经典操作,名曰“摄动”。
加快收敛 #
根据摄动法的思想,提高收敛速度的关键就是提高近似展开的精度,比如较为简单的改进是只假设$x_{t-1}$是小量,那么根据一阶泰勒展开有(将$u_t$视为列向量,这里的$\circ$是Hadamard积分)
\begin{equation}x_t = \tanh(Ax_{t-1} + u_t) \approx \tanh(u_t) + (\text{sech}^2 u_t\circ A)x_{t-1}\end{equation}
于是改进的结果就是式$\eqref{eq:rnn-iter}$变为
\begin{equation}x_t^{(n)} - A_t x_{t-1}^{(n)} = \tanh(Ax_{t-1}^{(n-1)} + u_t) - A_t x_{t-1}^{(n-1)}\label{eq:iter-plus1}\end{equation}
其中$A_t = \text{sech}^2 u_t\circ A$。更精细的改进是在每一步迭代时,都在前一步迭代结果的基础上进行展开:
\begin{equation}\begin{aligned}
x_t =&\, \tanh(Ax_{t-1} + u_t) \\
\approx&\, \tanh(Ax_{t-1}^{(n-1)} + u_t) + (\text{sech}^2 (Ax_{t-1}^{(n-1)} + u_t)\circ A)(x_{t-1} - x_{t-1}^{(n-1)})
\end{aligned}\end{equation}
于是式$\eqref{eq:rnn-iter}$变为
\begin{equation}x_t^{(n)} - A_t^{(n)} x_{t-1}^{(n)} = \tanh(Ax_{t-1}^{(n-1)} + u_t) - A_t^{(n)} x_{t-1}^{(n-1)}\label{eq:iter-plus2}\end{equation}
其中$A_t^{(n)}=\text{sech}^2 (Ax_{t-1}^{(n-1)} + u_t)\circ A$。最后的这个迭代格式,实际上就是求方程数值解的“牛顿法”,它具有二次收敛速度。
何必收敛 #
理论上来说,$\eqref{eq:iter-plus1}$、$\eqref{eq:iter-plus2}$两个改进确实能提高收敛速度,然而它们使得每一步线性递归的矩阵$A$变得跟$t$甚至$n$有关了,这其实会大大增加并行的复杂度,也不能利用“简化形式”一节的对角化技巧来加速。另一方面,如果保持$\eqref{eq:rnn-iter}$这样的迭代格式,虽然有诸多效率上的好处,但收敛方面确实无法得到很好的保障。
难道这两者的矛盾就无法调和了吗?事实上,按照笔者的观点,最直接的做法是“别去管它”——借助非线性RNN导出了$\eqref{eq:rnn-iter}$后,就忘记原本的非线性RNN,将式$\eqref{eq:rnn-iter}$作为基本模型。也就是说,何必忧虑式$\eqref{eq:rnn-iter}$会不会收敛到原来的非线性RNN?直接将它作为新的出发点不好吗?梯度下降学到怎样的结果就是怎样的结果,如果梯度下降学到的结果是不收敛到原来的非线性RNN,那么就意味着不收敛到原来的RNN是更适合的。
抛开这一层思维束缚后,其实很多问题会变得豁然开朗起来。首先,即便是式$\eqref{eq:iter-plus2}$在理论上拥有非常好的收敛速度,但也是有条件的,而且在深度学习的背景下,要保证这些条件会显得很奢侈。换言之,即便是式$\eqref{eq:iter-plus2}$的收敛性也没有绝对保证,所以何必“五十步笑百步”去苛责式$\eqref{eq:rnn-iter}$?其次,将式$\eqref{eq:rnn-iter}$视为新的出发点后,我们可以将它单纯地理解为线性RNN的一种新用法,或者说解决线性RNN缺陷(比如线性RNN不是图灵完备的)的一个思路,这样操作性更强。
总的来说,不去管它的收敛性,似乎更能打破思维僵局,探索更一般的结果。
一般情形 #
前面的“长篇大论”,都只围绕着简单的非线性RNN也就是式$\eqref{eq:rnn}$进行讨论,对于更常用的LSTM、GRU,结果又如何呢?
以GRU为例,它原本的形式为
\begin{equation}\begin{aligned} z_{t} & = \sigma \left( W_{z} x_{t} + U_{z} h_{t - 1} + b_{z} \right) \\
r_{t} & = \sigma \left( W_{r} x_{t} + U_{r} h_{t - 1} + b_{r} \right) \\
\hat{h}_t & = \tanh \left( W_{h} x_{t} + U_{h} (r_t \circ h_{t - 1}) + b_{c} \right)\\
h_{t} & = \left(1 - z_{t}\right) \circ h_{t - 1} + z_{t} \circ \hat{h}_t \end{aligned}\end{equation}
初始阶段,所有门控都可以近似视为$\frac{1}{2}$,那么模仿式$\eqref{eq:rnn-approx}$有
\begin{equation}\begin{aligned}
h_{t} &\, = \left(1 - z_{t}\right) \circ h_{t - 1} + z_{t} \circ \hat{h}_t \\
&\, \approx \frac{1}{2} h_{t - 1} + \frac{1}{2} \hat{h}_t \\
&\, \approx \frac{1}{2} h_{t - 1} + \frac{1}{2} \left(\tanh ( W_{h} x_{t} + b_{c} ) + \frac{1}{2}U_{h} h_{t - 1}\right) \\
&\, = \frac{1}{2} \left(I + \frac{1}{2}U_{h}\right)h_{t - 1} + \frac{1}{2} \tanh ( W_{h} x_{t} + b_{c} ) \\
\end{aligned}\end{equation}
所以可以选取$A=\frac{1}{2} \left(I + \frac{1}{2}U_{h}\right)$,将GRU改写为迭代
\begin{equation}\begin{aligned} z_{t}^{(n)} & = \sigma \left( W_{z} x_{t} + U_{z} h_{t - 1}^{(n-1)} + b_{z} \right) \\
r_{t}^{(n)} & = \sigma \left( W_{r} x_{t} + U_{r} h_{t - 1}^{(n-1)} + b_{r} \right) \\
\hat{h}_t^{(n)} & = \tanh \left( W_{h} x_{t} + U_{h} (r_t^{(n)} \circ h_{t - 1}^{(n-1)}) + b_{c} \right)\\
h_{t}^{(n)} & = Ah_{t-1}^{(n)} - Ah_{t-1}^{(n - 1)} + \left(1 - z_{t}^{(n)}\right) \circ h_{t - 1}^{(n-1)} + z_{t}^{(n)} \circ \hat{h}_t^{(n)} \end{aligned}\end{equation}
总的来说,这种将非线性RNN变为线性RNN迭代的转换,从实践的角度来看,就是以非线性RNN为引,导出一种多层线性RNN的参数共享和组合方法,迭代了几次,那么就有几层线性RNN的计算量。这样自然而言就引发了一个思考:除非可以证明GRU、LSTM等非线性RNN有绝对的优势,否则直接叠加几层“线性RNN+MLP”不好吗?
文章小结 #
本文简单探讨了非线性RNN的并行计算问题——通过数学物理中的“摄动”思想,我们可以将非线性RNN转化为线性RNN的迭代,从而利用线性RNN的可并行性来实现非线性RNN的并行。
转载到请包括本文地址:https://www.kexue.fm/archives/9783
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (Sep. 26, 2023). 《脑洞大开:非线性RNN居然也可以并行计算? 》[Blog post]. Retrieved from https://www.kexue.fm/archives/9783
@online{kexuefm-9783,
title={脑洞大开:非线性RNN居然也可以并行计算?},
author={苏剑林},
year={2023},
month={Sep},
url={\url{https://www.kexue.fm/archives/9783}},
}
September 26th, 2023
Google 最近有提出 TSMixer (https://openreview.net/forum?id=wbpxTuXgm0)。
扫了一眼,似乎不是特别有意思。
September 26th, 2023
最近這幾年有一個主張是透過 dynamical system 把 RNN 和 ODE/PDE/SDE 聯繫在一起。
您好,具体哪些 dynamical system? 方便列一些吗
这个联系比较自然吧。就本文提到的论文而言,它主要就是做ODE/PDE的,看上去做RNN只是顺便。
September 26th, 2023
[...]Read More [...]
September 29th, 2023
爲甚麼 公式 (7) 的中只需要初始化保證 $\mathbf{P}\mathbf{Q} = \mathbf{\Lambda}$ 就可以了? Neural Networks 在訓練的過程中,weight tensor 不是會一直改變的嗎?$\mathbf{P}\mathbf{Q} = \mathbf{\Lambda}$ 在訓練過程中可能會無法保證也沒關係?
如果训练过程不保证$\boldsymbol{P}\boldsymbol{Q} = \boldsymbol{\Lambda}$,可能是因为优化器觉得$\boldsymbol{P}\boldsymbol{Q} \neq \boldsymbol{\Lambda}$更好,所以并没有什么关系。
October 10th, 2023
hello,从感受野的角度来看,迭代的过程,$x_t$要看到$x_0$,不是整个不动点迭代要跑t个step吗,这和原来的rnn比,为什么能快呢?
可以用 prefix sum 來並行運算 RNN,這樣就比原先衹能迭代運算 RNN 要快很多。
线性RNN是可以并行的(参考 https://kexue.fm/archives/9554 ),理想情况下复杂度是$\mathcal{O}(\log t)$,所以总的复杂度再乘以迭代步数,因此时间上有可能少于直接计算非线性RNN。
October 18th, 2023
这个(16)是怎么推导出来的,有大神能解释一下细节吗
所有的中间变量都用上一步的迭代结果计算,然后用两端减去一阶近似的方式去构造最终的迭代格式,保证$^{(n-1)}$替换为$^{(n)}$时,退化为原本的非线性RNN。
November 12th, 2023
令 $\mathbf{P}$ 爲 learnable orthogonal matrix 即可,比如可以用 expRNN 裡提出的方法。
这样计算复杂度又更大了?
還有一個辦法,在公式 (3) 中,用 $\mathbf{A} = \mathbf{F}\mathbf{\Lambda}\mathbf{F}^{\ast}$ 替代 $\mathbf{A} = \mathbf{P}\mathbf{\Lambda}\mathbf{P}^{-1}$,其中 $\mathbf{F}$ 是 DFT 矩陣。從 signal processing 的角度來看,$\mathbf{P}^{-1}$ 和 $\mathbf{P}$ 的作用是對 $\mathbf{x}_{t-1}$ 做 discrete orthogonal transform 和 inverse transform。用 DFT 矩陣替代之後,可以用 FFT 來加速運算。
这样的话感觉干脆换用基于FFT的模型得了(捂脸)~
說不定 FFT based Linear RNN 確實可以
還真有這樣的模型(https://openreview.net/forum?id=IxmWsm4xrua)
这个就是广义的FFT呀,并不是反过来用FFT来加速RNN
不會。我最近在試類似的方法,時間複雜度和空間複雜度僅比O(NlogN)大一點點。
December 7th, 2023
有个不太懂的是,这种技术可以用于RNN的并行推理吗?还是只能用于并行训练
理论上可以,但实际上有点难,因为无法提前预知结果长度。
September 2nd, 2024
想问一下以(8)式为例,从step=n-1到step=n,右式的$x_{t-1}^{n-1}$如何得到呢?前一个式子得到的是$x_{t}^{n-1}-Ax_{t-1}^{n-1}$,如果想要单独求出$x_{t-1}^{n-1}$是不是还需要类似于递推或者开一个大数组从$x_{1}^{n-1}$开始往后一个一个计算+存储?
如果是这样的话,对序列的初始值$x_1$,(9)式有:$x_1^{0}=tanh(u_1)$(因为$x_0 \triangleq 0$),那就意味着$x_1$是固定且无损的,不会随着迭代而更新,导致后面所有的$x_{t}^{n}$都无法更新了?
第一步,$x_t^{(1)} - Ax_{t-1}^{(1)} = \tanh(u_t)$,这就是个线性RNN,可以并行求出$x_1^{(1)},x_2^{(1)},x_3^{(1)},\cdots$;
第二步,$x_t^{(2)} - Ax_{t-1}^{(2)} = \tanh(Ax_{t-1}^{(1)} + u_t) - Ax_{t-1}^{(1)}$,其中所有的$x_1^{(1)},x_2^{(1)},x_3^{(1)},\cdots$已经在上一步求出,因此也可以并行求出$x_1^{(2)},x_2^{(2)},x_3^{(2)},\cdots$;
依此类推~
嗯嗯,这个逻辑乍一看很显然,但是感觉细想还是不太对。$n$迭代过程中$x^{(n)}$能够产生更新的原因还是因为真值在随着$n$的迭代的过程中在跟着$t$进行传递吧,这样理论上$n=t$才可以收敛?
第一步$x_1^{(1)}=tanh(u_1)$已经是真值$x_1$了,而且这个真值每一步迭代都是这样的形式,不会改变($x_1$是首项,即$x_0 \triangleq 0$,因此$x_1$的递推式和后面的不同);第二步这个真值根据$x_2^{(2)}-Ax_1^{(2)}=tanh(Ax_1^{(1)}+u_2)-Ax_1^{(1)}$,由于$x_1^{(1)}=x_1^{(2)}$,所以$x_2^{(2)}=tanh(Ax_1^{(1)}+u_2)=tanh(Ax_1+u_2)$,也即正确的$x_1$在迭代第二步$n=2$的时候传递到了$x_2$处,使得$x_2$也获得了真值。
之后类似,整体的迭代的收敛性其实是因为真值在随着$n$根据$t$来逐步往后传的,如果真值传不到($n \ll t_{length}$),误差会很大吧?
所以你的意思是想要表达这样做的误差可能很大,以及你为什么觉得误差大的一些看法?
如果这样的话,那这个问题确实无解了,要严格保证它收敛于原始非线性RNN还是要静心调试的,我更倾向于将它当成一种复杂线性RNN的构建方式,事后就当它线性RNN来用,保持训练和预测的一致性。