Skip to main content

超越自注意力:External Attention

论文名称:Beyond Self-attention: External Attention using Two Linear Layers for Visual Tasks

作者:Meng-Hao Guo, Zheng-Ning Liu, Tai-Jiang Mu, Shi-Min Hu, Senior Member

Code:https://github.com/MenghaoGuo/EANet

前言

自从 SelfAttentionSelf-Attention 在 NLP 中被提出之后,许多工作尝试将其引入计算机视觉的各种任务之中,即使其被证明是一种行之有效的方法,其缺点终将会阻碍自身的进一步发展:

  1. SelfAttentionSelf-Attention 的计算复杂度呈二次增长
  2. 只能学习到单个样本中的长距离关系,但是无法考虑所有样本之间的关联

因此,提出了 ExternalAttentionExternal-Attention,仅仅通过两个可学习的单元,实现了线性的时间复杂度,并且能够考虑到所有样本之间的关联,与此同时,对整个网络起着重要的正则作用,有利于增强网络的泛化性

External Attention

image-20210614182722717

Self Attention

公式为:

Fout=softmax(WQF(WKF)T)WVFF_{out}=softmax(W_QF(W_KF)^T)W_VF

具体内容移步

另一个简化 SelfAttentionSelf\quad Attention 的公式为

Fout=softmax(FFT)FF_{out}=softmax(FF^T)F

其时间复杂度为 O(dN2)O(dN^2)

在计算机视觉的各类任务中,显然,我们不需要 NXN 的注意力图,因为像素点之间关系并不想 NLP 中词语之间的关系那么密切,因而有的工作提出了在 patch 上计算注意力。

即便能在一定程度上缓解计算消耗,SelfAttentionSelf-Attention 忽略了不同样本间数据的关系,在一定程度上限制了其灵活度和泛化性

External Attention

不再通过原始输入来获得 K,VK,V,而使用两个独立的 MemoryUnitMemory\quad Unit,公式如下:

A=αi,j=Norm(FMKT)A=\alpha_{i,j}=Norm(FM^T_K)

在原文中,这里的 αi,j\alpha_{i,j} 指的是第 ii 个像素与 MkM_k 中第 jj 行的相似度,我认为体现的并不明显(emmmm 我太菜了)

让我们做一个简单的推理(这里的矩阵的长宽都假设为 N):

FMKT=[f1f2fnfn+1fn+2f2nfn2n+1fn2n+2fn2][m1,1m2,1mn,1m1,2m2,2mn,2m1,nm2,nmn,n]FM_K^T= \left[\begin{array} {c}f_{1} & f_{2}&\cdots&f_{n}\\ f_{n+1}& f_{n+2}&\cdots&f_{2n}\\ \vdots&\vdots&\ddots&\vdots\\ f_{n^2-n+1}&f_{n^2-n+2}&\cdots&f_{n^2} \end{array}\right] \left[\begin{array} {c}m_{1,1} & m_{2,1}&\cdots&m_{n,1}\\ m_{1,2}& m_{2,2}&\cdots&m_{n,2}\\ \vdots&\vdots&\ddots&\vdots\\ m_{1,n}&m_{2,n}&\cdots&m_{n,n} \end{array}\right]

根据矩阵的运算规则,对于 fif_i(它在第 jj 列),那么 MKTM_K^T 的第 jj 行的所有元素都会和 fif_i 相乘(注意这里是转置之后的行,在原来的 MKM_K 上应该是列,更加迷惑了),然而它们并不会加在一起,而是分布在一个维度 (1,n)(1,n) 的矩阵中,并且还会有其他数据的影响,因为 fif_i 所在的一行会与 MKTM_K^T 的每一列乘加,

综上,我不知道相似度体现在那里

总之,ES 将时间复杂度降到了线性,其输出为:

Fout=AMVF_{out}=AM_V

多头 External Attention 的实现如下

image-20210614203458185

归一化

我们知道,SoftmaxSoftmax 对输入的尺度大小十分敏感,因此在自注意力中使用了缩放点积来减小影响,具体移步

在 ES 中,使用 DoubleNormalizationDouble\quad Normalization 分别在行和列上进行归一化

这里的 DoubleNormalizationDouble\quad Normalization 实际上就是先对列进行 SoftmaxSoftmax 再对行进行 l1norml_1norm

代码

class External_attention(Module):
'''
Arguments:
c (int): The input and output channel number.
'''
def __init__(self, c):
super(External_attention, self).__init__()

self.conv1 = nn.Conv2d(c, c, 1)

self.k = 64
self.linear_0 = nn.Conv1d(c, self.k, 1, bias=False)

self.linear_1 = nn.Conv1d(self.k, c, 1, bias=False)
self.linear_1.weight = self.linear_0.weight.permute(1, 0, 2) #初始化成一样是为了让 Mk 和 Mv 最开始时候是一致的,类似于 self-attention 里面的 K 和 V 都是从 X 经过线性变换来的,这种设置在训练初始阶段会稳定一些,但是貌似不会影响最终结果。

self.conv2 = nn.Sequential(
nn.Conv2d(c, c, 1, bias=False),
nn.BatchNorm(c))

self.relu = nn.ReLU()

def execute(self, x):
idn = x
x = self.conv1(x)

b, c, h, w = x.size()
n = h*w
x = x.view(b, c, h*w) # b * c * n

attn = self.linear_0(x) # b, k, n
attn = torch.softmax(attn, dim=-1) # b, k, n

attn = attn / (1e-9 + attn.sum(dim=1, keepdims=True)) # # b, k, n
x = self.linear_1(attn) # b, c, n

x = x.view(b, c, h, w)
x = self.conv2(x)
x = x + idn
x = self.relu(x)
return x

多头 EA 请自行前往仓库查找

思考

先看大佬怎么说:https://www.zhihu.com/search?type=content&q=External%20Attention

感觉 External Attention 与 Bottleneck 有一定的相似程度,首尾通过 1X1 卷积实现残差连接,不过在连接的过程中加入了 Softmax 和 l1norm 起到一定信息交换的作用

关于 MemoryUionMemory\quad Uion:

在之前就有工作把 position-wise 的前馈层当作一个 Memory 模块

关于 DoubleNormalizationDouble\quad Normalization

如果你看完了输入尺度对 SoftmaxSoftmax 的影响,应该会思考,为什么要将 l1norm 添加在后面呢?

当输入特征方差过大时,SoftmaxSoftmax 输出的会是一个接近于 OnehotOne-hot 的向量,此时再做 l1norm 显然不会有什么变化吧

或者说这里的 SoftmaxSoftmax 可以提取的是全局信息?

关于相似度:见 2.2