argmax
在学习 gumbel softmax 之前,我们首先需要了解它的远方亲戚 argmax。
假设我们有一个概率分布向量如下:
[0.3,0.1,0.1,0.5]
对于 argmax 来说,显然每次的结果都会是 3 ,因为该位置的概率值最大。但是从概率上来说,只有 50% 的概率会选到第 3 个位置,而使用 argmax 则会有 100% 的概率选中第 3 个位置,这显然是不合理的。
基于 argmax 的采样如下:
pos = argmax(logits)
sample = logits[pos]
softmax
argmax 能直接得到最大概率的位置,我们通常需要在分类、分割任务中这样做。但是 argmax 是不可微的,这样会阻碍反向传播,于是提出了 softmax。
softmax 是 argmax 的光滑近似,其可以拉大输入向量之间的差距,并且可微,能够正常的计算梯度反向传播。
基于 softmax 的采样如下:
pro = sotfmax(logits)
sample = np.random.choice(len(logits),1, p=pro)
虽然 softmax 可微,但是基于 softmax 的采样仍然不能反向传播。
gumbel max
让我们首先在 argmax 中引入随机性——gumbel 分布,其是一种极值分布,表示某个随机变量在不同时间段中极值的概率分布。比如一个人每天喝 8 次水,显然这 8 次中的极值也是一个随机变量,该随机变量随着时间的分布即为 gumbel 分布。
其累积分布函数为
F(x)=e−e−x
我们可以通过求解其反函数来利用概率生成随机数:
G=−log(−log(x))
我们通过生成与输入向量维度相同的均匀分布向量,从 gunbel 分布中进行采样,以此获得随机性:
Gi=−log(−log(εi)),εi∈U(0,1)
于是可以得到最终的公式:
x=argmax(log(pi)+Gi)
这其实是一种重参数化的过程,具体见此
并且我们可以证明,gumbel max 输出 i 的概率刚好对应 pi。
首先我们证明输出 1 的概率是 p1,输出 1 意味着 logp1−log(−logε1) 最大,也就是说以下不等式成立:
logp1−log(−logε1)>logp2−log(−logε2)logp1−log(−logε1)>logp3−log(−logε3)⋮logp1−log(−logε1)>logpk−log(−logεk)
注意这里每个不等式是独立的,p1 与 p2 的关系并不影响 p1 和 p3 的关系。
首先分析第一个不等式,化简可得:
ε2<ε1p2/p1≤1
由于 ε 是从均匀分布中采样的,因此我们知道 ε2<ε1p2/p1 的概率就是 ε1p2/p1,对于某一个固定的 ε1,当所有不等式同时成立时,概率为:
ε1p2/p1ε1p3/p1…ε1pk/p1=ε1(p2+p3+⋯+pk)/p1=ε1(1/p1)−1
对于所有的 ε1 ,我们可以得出其概率:
∫01ε1(1/p1)−1dε1=p1
gumbel softmax
由于 argmax 不可导,我们可以使用其近似函数——softmax
x=softmax((log(pi)+Gi)/τ)
τ 表示温度,是一种退火技巧,其值越小,输出结果越接近 one hot 的形式,但同时梯度消失的情况就越严重。