深度学习的互信息:无监督提取特征
By 苏剑林 | 2018-10-02 | 259155位读者 |对于NLP来说,互信息是一个非常重要的指标,它衡量了两个东西的本质相关性。本博客中也多次讨论过互信息,而我也对各种利用互信息的文章颇感兴趣。前几天在机器之心上看到了最近提出来的Deep INFOMAX模型,用最大化互信息来对图像做无监督学习,自然也颇感兴趣,研读了一番,就得到了本文。
本文整体思路源于Deep INFOMAX的原始论文,但并没有照搬原始模型,而是按照这自己的想法改动了模型(主要是先验分布部分),并且会在相应的位置进行注明。
我们要做什么 #
自编码器 #
特征提取是无监督学习中很重要且很基本的一项任务,常见形式是训练一个编码器将原始数据集编码为一个固定长度的向量。自然地,我们对这个编码器的基本要求是:保留原始数据的(尽可能多的)重要信息。
我们怎么知道编码向量保留了重要信息呢?一个很自然的想法是这个编码向量应该也要能还原出原始图片出来,所以我们还训练一个解码器,试图重构原图片,最后的loss就是原始图片和重构图片的mse。这导致了标准的自编码器的设计。后来,我们还希望编码向量的分布尽量能接近高斯分布,这就导致了变分自编码器。
重构的思考 #
然而,值得思考的是“重构”这个要求是否合理?
首先,我们可以发现通过低维编码重构原图的结果通常是很模糊的,这可以解释为损失函数mse要求“逐像素”重建过于苛刻。又或者可以理解为,对于图像重构事实上我们并没有非常适合的loss可以选用,最理想的方法是用对抗网络训练一个判别器出来,但是这会进一步增加任务难度。
其次,一个很有趣的事实是:我们大多数人能分辨出很多真假币,但如果要我们画一张百元大钞出来,我相信基本上画得一点都不像。这表明,对于真假币识别这个任务,可以设想我们有了一堆真假币供学习,我们能从中提取很丰富的特征,但是这些特征并不足以重构原图,它只能让我们分辨出这堆纸币的差异。也就是说,对于数据集和任务来说,合理的、充分的特征并不一定能完成图像重构。
最大化互信息 #
互信息 #
上面的讨论表明,重构不是好特征的必要条件。好特征的基本原则应当是“能够从整个数据集中辨别出该样本出来”,也就是说,提取出该样本(最)独特的信息。如何衡量提取出来的信息是该样本独特的呢?我们用“互信息”来衡量。
让我们先引入一些记号,用$X$表示原始图像的集合,用$x\in X$表示某一原始图像,$Z$表示编码向量的集合,$z\in Z$表示某个编码向量,$p(z|x)$表示$x$所产生的编码向量的分布,我们设它为高斯分布,或者简单理解它就是我们想要寻找的编码器。那么可以用互信息来表示$X,Z$的相关性
$$\begin{equation}I(X,Z) = \iint p(z|x)\tilde{p}(x)\log \frac{p(z|x)}{p(z)}dxdz\label{eq:mi}\end{equation}$$
这里的$\tilde{p}(x)$原始数据的分布,$p(z)$是在$p(z|x)$给定之后整个$Z$的分布,即
$$\begin{equation}p(z) = \int p(z|x)\tilde{p}(x)dx\end{equation}$$
那么一个好的特征编码器,应该要使得互信息尽量地大,即
$$\begin{equation}p(z|x) = \mathop{\text{argmax}}_{p(z|x)} I(X,Z) \end{equation}$$
互信息越大意味着(大部分的)$\log \frac{p(z|x)}{p(z)}$应当尽量大,这意味着$p(z|x)$应当远大于$p(z)$,即对于每个$x$,编码器能找出专属于$x$的那个$z$,使得$p(z|x)$的概率远大于随机的概率$p(z)$。这样一来,我们就有能力只通过$z$就从中分辨出原始样本来。
注意:$\eqref{eq:mi}$的名称为互信息,而对数项$\log \frac{p(z|x)}{p(z)}$我们称为“点互信息”,有时也直接称为互信息。两者的差别是:$\eqref{eq:mi}$计算的是整体关系,比如回答“前后两个词有没有关系”的问题;$\log \frac{p(z|x)}{p(z)}$计算的是局部关系,比如回答“‘忐’和‘忑’是否经常连在一起出现”的问题。
先验分布 #
前面提到,相对于自编码器,变分自编码器同时还希望隐变量服从标准正态分布的先验分布,这有利于使得编码空间更加规整,甚至有利于解耦特征,便于后续学习。因此,在这里我们同样希望加上这个约束。
Deep INFOMAX论文中通过类似AAE的思路通过对抗来加入这个约束,但众所周知对抗是一个最大最小化过程,需要交替训练,不够稳定,也不够简明。这里提供另一种更加端到端的思路:设$q(z)$为标准正态分布,我们去最小化$p(z)$与先验分布$q(z)$的KL散度
$$\begin{equation}\label{eq:prior}KL(p(z)\Vert q(z))=\int p(z)\log \frac{p(z)}{q(z)}dz\end{equation}$$
将$\eqref{eq:mi}$与$\eqref{eq:prior}$加权混合起来,我们可以得到最小化的总目标:
$$\begin{equation}\begin{aligned}p(z|x) =& \min_{p(z|x)} \left\{- I(X,Z) + \lambda KL(p(z)\Vert q(z))\right\}\\
=&\min_{p(z|x)}\left\{-\iint p(z|x)\tilde{p}(x)\log \frac{p(z|x)}{p(z)}dxdz + \lambda\int p(z)\log \frac{p(z)}{q(z)}dz\right\}\end{aligned}\label{eq:total-loss-1}\end{equation}$$
看起来很清晰很美好,但是我们还不知道$p(z)$的表达式,也就没法算下去了,因此这事还没完。
逐个击破 #
简化先验项 #
有意思的是式$\eqref{eq:total-loss-1}$的loss进行稍加变换得到:
$$\begin{equation}p(z|x) =\min_{p(z|x)}\left\{\iint p(z|x)\tilde{p}(x)\left[-(1+\lambda)\log \frac{p(z|x)}{p(z)} + \lambda \log \frac{p(z|x)}{q(z)}\right]dxdz\right\}\end{equation}$$
注意上式正好是互信息与$\mathbb{E}_{x\sim\tilde{p}(x)}[KL(p(z|x)\Vert q(z))]$的加权求和,而$KL(p(z|x)\Vert q(z))$这一项是可以算出来的(正好是VAE的那一项KL散度),所以我们已经成功地解决了整个loss的一半,可以写为
$$\begin{equation}p(z|x) =\min_{p(z|x)}\left\{-\beta\cdot I(X,Z)+\gamma\cdot \mathbb{E}_{x\sim\tilde{p}(x)}[KL(p(z|x)\Vert q(z))]\right\}\label{eq:total-loss-2}\end{equation}$$
下面我们主攻互信息这一项。
互信息本质 #
现在只剩下了互信息这一项没解决了,怎么才能最大化互信息呢?我们把互信息的定义$\eqref{eq:mi}$稍微变换一下:
$$\begin{equation}\begin{aligned}I(X,Z) =& \iint p(z|x)\tilde{p}(x)\log \frac{p(z|x)\tilde{p}(x)}{p(z)\tilde{p}(x)}dxdz\\
=& KL(p(z|x)\tilde{p}(x)\Vert p(z)\tilde{p}(x))
\end{aligned}\end{equation}$$
这个形式揭示了互信息的本质含义:$p(z|x)\tilde{p}(x)$描述了两个变量$x,z$的联合分布,$p(z)\tilde{p}(x)$则是随机抽取一个$x$和一个$z$时的分布(假设它们两个不相关时),而互信息则是这两个分布的KL散度。而所谓最大化互信息,就是要拉大$p(z|x)\tilde{p}(x)$与$p(z)\tilde{p}(x)$之间的距离。
注意KL散度理论上是无上界的,我们要去最大化一个无上界的量,这件事情有点危险,很可能得到无穷大的结果。所以,为了更有效地优化,我们抓住“最大化互信息就是拉大$p(z|x)\tilde{p}(x)$与$p(z)\tilde{p}(x)$之间的距离”这个特点,我们不用KL散度,而换一个有上界的度量:JS散度(当然理论上也可以换成Hellinger距离,请参考《f-GAN简介:GAN模型的生产车间》),它定义为
$$JS(P,Q) = \frac{1}{2}KL\left(P\left\Vert\frac{P+Q}{2}\right.\right)+\frac{1}{2}KL\left(Q\left\Vert\frac{P+Q}{2}\right.\right)$$
JS散度同样衡量了两个分布的距离,但是它有上界$\frac{1}{2}\log 2$,我们最大化它的时候,同样能起到类似最大化互信息的效果,但是又不用担心无穷大问题。于是我们用下面的目标取代式$\eqref{eq:total-loss-2}$
$$\begin{equation}p(z|x) =\min_{p(z|x)}\left\{-\beta\cdot JS\big(p(z|x)\tilde{p}(x), p(z)\tilde{p}(x)\big)+\gamma\cdot \mathbb{E}_{x\sim\tilde{p}(x)}[KL(p(z|x)\Vert q(z))]\right\}\label{eq:total-loss-3}\end{equation}$$
当然,这并没有改变问题的本质和难度,JS散度也还是没有算出来。下面到了攻关的最后一步。
攻克互信息 #
在文章《f-GAN简介:GAN模型的生产车间》中,我们介绍了一般的$f$散度的局部变分推断(那篇文章的式$(13)$)
$$\begin{equation}\mathcal{D}_f(P\Vert Q) = \max_{T}\Big(\mathbb{E}_{x\sim p(x)}[T(x)]-\mathbb{E}_{x\sim q(x)}[g(T(x))]\Big)\label{eq:f-div-e}\end{equation}$$
对于JS散度,给出的结果是
$$\begin{equation}JS(P,Q) = \max_{T}\Big(\mathbb{E}_{x\sim p(x)}[\log \sigma(T(x))] + \mathbb{E}_{x\sim q(x)}[\log(1-\sigma(T(x))]\Big)\end{equation}$$
代入$p(z|x)\tilde{p}(x), p(z)\tilde{p}(x)$就得到
$$\begin{equation}\begin{aligned}&JS\big(p(z|x)\tilde{p}(x), p(z)\tilde{p}(x)\big)\\=& \max_{T}\Big(\mathbb{E}_{(x,z)\sim p(z|x)\tilde{p}(x)}[\log \sigma(T(x,z))] + \mathbb{E}_{(x,z)\sim p(z)\tilde{p}(x)}[\log(1-\sigma(T(x,z))]\Big)\end{aligned}\label{eq:f-div-e-js}\end{equation}$$
你没看错,除去常数项不算,它就完全等价于deep INFOMAX论文中的式$(5)$。我很奇怪,为什么论文作者放着上面这个好看而直观的形式不用,非得故弄玄虚搞个让人茫然的形式。其实$\eqref{eq:f-div-e-js}$式的含义非常简单,它就是“负采样估计”:引入一个判别网络$\sigma(T(x,z))$,$x$及其对应的$z$视为一个正样本对,$x$及随机抽取的$z$则视为负样本,然后最大化似然函数,等价于最小化交叉熵。
这样一来,通过负采样的方式,我们就给出了估计JS散度的一种方案,从而也就给出了估计JS版互信息的一种方案,从而成功攻克了互信息。现在,对应式$\eqref{eq:total-loss-3}$,具体的loss为
$$\begin{equation}\begin{aligned}&p(z|x),T(x,z) \\
=&\min_{p(z|x),T(x,z)}\Big\{-\beta\cdot\Big(\mathbb{E}_{(x,z)\sim p(z|x)\tilde{p}(x)}[\log \sigma(T(x,z))] + \mathbb{E}_{(x,z)\sim p(z)\tilde{p}(x)}[\log(1-\sigma(T(x,z))]\Big)\\
&\qquad\qquad\qquad+\gamma\cdot \mathbb{E}_{x\sim\tilde{p}(x)}[KL(p(z|x)\Vert q(z))]\Big\}\end{aligned}\label{eq:total-loss-4}\end{equation}$$
现在,理论已经完备了,剩下的就是要付诸实践了。
从全局到局部 #
batch内打乱 #
从实验上来看,式$\eqref{eq:total-loss-4}$就是要怎么操作呢?先验分布的KL散度那一项不难办,照搬VAE的即可。而互信息那一项呢?
首先,我们随机选一张图片$x$,通过编码器就可以得到$z$的均值和方差,然后重参数就可以得到$z_x$,这样的一个$(x, z_x)$对构成一个正样本呢;负样本呢?为了减少计算量,我们直接在batch内对图片进行随机打乱,然后按照随机打乱的顺序作为选择负样本的依据,也就是说,如果$x$是原来batch内的第4张图片,将图片打乱后第4张图片是$\hat{x}$,那么$(x,z_x)$就是正样本,$(\hat{x},z_x)$就是负样本。
局部互信息 #
上面的做法,实际上就是考虑了整张图片之间的关联,但是我们知道,图片的相关性更多体现在局部中(也就是因为这样所以我们才可以对图片使用CNN)。换言之,图片的识别、分类等应该是一个从局部到整体的过程。因此,有必要把“局部互信息”也考虑进来。
通过CNN进行编码的过程一般是:
$$\text{原始图片}x\xrightarrow{\text{多个卷积层}} h\times w\times c\text{的特征} \xrightarrow{\text{卷积和全局池化}} \text{固定长度的向量}z$$
我们已经考虑了$x$和$z$的关联,那么中间层特征(feature map)和$z$的关联呢?我们记中间层特征为$\{C_{ij}(x)|i=1,2,\dots,h;j=1,2,\dots,w\}$也就是视为$h\times w$个向量的集合,我们也去算这$h\times w$个向量跟$z_x$的互信息,称为“局部互信息”。
估算方法跟全局是一样的,将每一个$C_{ij}(x)$与$z_x$拼接起来得到$[C_{ij}(x), z_x]$,相当于得到了一个更大的feature map,然后对这个feature map用多个1x1的卷积层来作为局部互信息的估算网络$T_{local}$。负样本的选取方法也是用在batch内随机打算的方案。
现在,加入局部互信息的总loss为
$$\begin{equation}\begin{aligned}&p(z|x),T_1(x,z),T_2(C_{ij}, z)=\min_{p(z|x),T_1,T_2}\Big\{\\
&\quad-\alpha\cdot\Big(\mathbb{E}_{(x,z)\sim p(z|x)\tilde{p}(x)}[\log \sigma(T_1(x,z))] + \mathbb{E}_{(x,z)\sim p(z)\tilde{p}(x)}[\log(1-\sigma(T_1(x,z))]\Big)\\
&\quad-\frac{\beta}{hw}\sum_{i,j}\Big(\mathbb{E}_{(x,z)\sim p(z|x)\tilde{p}(x)}[\log \sigma(T_2(C_{ij},z))] + \mathbb{E}_{(x,z)\sim p(z)\tilde{p}(x)}[\log(1-\sigma(T_2(C_{ij},z))]\Big)\\
&\quad+\gamma\cdot \mathbb{E}_{x\sim\tilde{p}(x)}[KL(p(z|x)\Vert q(z))]\Big\}\end{aligned}\label{eq:total-loss-5}\end{equation}$$
其他信息 #
其实,还有很多其他的信息可以考虑进去。
比如我们已经考虑了$C_{ij}$与$z$的互信息,还可以考虑的是$C_{ij}$之间的互信息,即同一张图片之间的$C_{ij}$应当是有关联的,它们的互信息应该尽可能大(正样本),而不同图片之间的$C_{ij}$应当是没关联的,它们的互信息应该尽可能小。不过我实验过,这一项的提升不是特别明显。
还有多尺度信息,可以手动在输入图片那里做多尺度的数据扩增,又或者是在编码器这些引入多尺度结构、Attention结构。诸如此类的操作,都可以考虑引入到无监督学习中,提高编码质量。
类似的word2vec #
其实,熟悉NLP中的word2vec模型原理的读者应该会感觉到:这不就是图像中的word2vec吗?
没错,在原理和做法上deep INFOMAX跟word2vec大体都一样。在word2vec中,也是随机采集负样本,然后通过判别器来区分两者的过程。这个过程我们通常称为“噪声对比估计”,我们之前也提到过,word2vec的噪声对比估计过程(负采样)的实际优化目标就是互信息。(细节请参考《“噪声对比估计”杂谈:曲径通幽之妙》)
word2vec中,固定了一个窗口大小,然后在窗口内统计词的共现(正样本)。而deep INFOMAX呢?因为只有一张图片,没有其他“词”,所以它干脆把图片分割为一个个小块,然后把一张图片当作一个窗口,图片的每个小块就是一个个词了。当然,更准确地类比的话,deep INFOMAX更像类似word2vec的那个doc2vec模型。
换个角度来想,也可以这样理解:局部互信息的引入相当于将每个小局部也看成了一个样本,这样就相当于原来的1个样本变成了$1+hw$个样本,大大增加了样本量,所以能提升效果。同时这样做也保证了图片的每一个“角落”都被用上了,因为低维压缩编码,比如$32\times 32\times 3$编码到128维,很可能左上角的$8\times 8\times 3 > 128$的区域就已经能够唯一分辨出图片出来了,但这不能代表整张图片,因此要想办法让整张图片都用上。
开源和效果图 #
参考代码 #
其实上述模型的实现代码应当说还是比较简单的(总比我复现glow模型容易几十倍~),不管用哪个框架都不困难,下面是用Keras实现的一个版本(Python 2.7 + Tensorflow 1.8 + Keras 2.2.4):
Github:https://github.com/bojone/infomax
来,上图片 #
无监督的算法好坏比较难定量判断,一般都是通过做很多下游任务看效果的。就好比当初词向量很火时,怎么定量衡量词向量的质量也是一个很头疼的问题。deep INFOMAX论文中做了很多相关实验,我这里也不重复了,只是看看它的KNN效果(通过一张图片查找最相近的k张图片)。
总的来说效果差强人意,我觉得精调之后做一个简单的以图搜图问题不大。原论文中的很多实验效果也都不错,进一步印证了该思路的威力~
cifar10 #
每一行的左边第一张是原始图片,右边9张是最邻近图片,用的是cos相似度。用欧氏距离的排序结果类似。
Tiny Imagenet #
每一行的左边第一张是原始图片,右边9张是最邻近图片,用的是cos相似度。用欧氏距离的排序结果类似。
全局 vs 局部 #
局部互信息的引入是很必要的,下面比较了只有全局互信息和只有局部互信息时的KNN的差异。
又到终点站 #
作为无监督学习的成功,将常见于NLP的互信息概念一般化、理论化,然后用到了图像中。当然,现在看来它也可以反过来用回NLP中,甚至用到其他领域,因为它已经被抽象化了,适用性很强。
deep INFOMAX整篇文章的风格我是很喜欢的:从一般化的理念(互信息最大化)到估算框架再到实际模型,思路清晰,论证完整,是我心中的理想文章的风格(除了它对先验分布的处理用了对抗网络,我认为这是没有必要的)。期待看到更多的这类文章。
转载到请包括本文地址:https://www.kexue.fm/archives/6024
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (Oct. 02, 2018). 《深度学习的互信息:无监督提取特征 》[Blog post]. Retrieved from https://www.kexue.fm/archives/6024
@online{kexuefm-6024,
title={深度学习的互信息:无监督提取特征},
author={苏剑林},
year={2018},
month={Oct},
url={\url{https://www.kexue.fm/archives/6024}},
}
December 4th, 2020
苏神您好!
1.请问最小化互信息应该就添加一个负号就行吧?
2.请问如果对于需要最小化多个互信息有没有效率比较高的方式呀?比如minimize MI(x;y;z)这种情况呢?还是说需要两两分别进行最小化互信息?
February 13th, 2022
苏老师,您好,打扰了。1.想问一下第二个公式(2式)是什么($p(z)$为什么会等于右式),有点懵;非常感谢老师的指导!
老师,不好意思,刚刚脑子短路了,老师不用理会这个问题
December 26th, 2022
苏神你好,我想问下6式那个λ化简之后不是抵消前后为0了么
一个是$p(z)$,一个是$q(z)$,怎么抵消?
January 11th, 2023
苏老师,想请教您一个问题,互信息可以在不同维度的变量之间计算吗?例如MNIST数据集,输入是28*28的图片,而对应的标签只有1维,那么X是[28*28]的向量,Y就是对应的标签。每次取样本的时候,采样得到的x有28*28个值,而y只有一个值,这样两者如何计算互信息呀?
点互信息是$\frac{p(x,y)}{p(x)p(y)}$,互信息是$\iint p(x,y)\frac{p(x,y)}{p(x)p(y)}dxdy$,估计每一个亮,然后代入公式计算就行。
October 19th, 2023
苏神您好,您文章中(3)式的含义是不是“找到令$I(X;Z)$最大的$p(z|x)$”呢?
如果是的话,是不是用$p(z|x)=\mathop{argmax}\limits_{p(z|x)}\{I(X;Z)\}$更准确些呢?
是的,感谢建议,已修改。
July 17th, 2024
互信息越大意味着(大部分的)$log\frac{p(z|x)}{p(z)}$应当尽量大,这意味着p(z|x)应当远大于p(z),即对于每个x,编码器能找出专属于x的那个z,使得p(z|x)的概率远大于随机的概率p(z)。这样一来,我们就有能力只通过z就从中分辨出原始样本来。
苏神,这段话不是很理解,p(z|x)不是由x推测z的概率吗,而“只通过z就从中分辨出原始样本来。”不是指p(x|z)吗?还有“随机的概率p(z)”,这句话怎么理解呢,为什么说是随机的。
简单来说,$p(z)$是无条件分布,$p(z|x)$是条件分布,如果$p(z|x)\gg p(z)$,意味着条件$x$跟$z$的关系更密切。