FLASH:一种高效的Transformer设计

FLASH是一种针对长序列处理的高效Transformer模型,通过GAU单元融合注意力机制与前馈网络,实现更快速且节省显存的目标。文章介绍了FLASH的工作原理,包括局部与全局注意力计算方法,以及在不同任务上的实证表现。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

背景

近年来,Transformer凭借其优秀的设计,在文本、图像、语音等方向大杀四方。但是由于其attention的二次复杂度限制了其在长序列上的应用。本文提出了一种快(速度快)、省(省显存)的模型FLASH(Fast Linear Attention with a Single Head),在长序列的表现远远高于标准的Transformer。

模型介绍

GAU(Gated Attention Unit)

在标准的Transformer结构中,多头注意力和FFN是交替连接的。GLU那篇论文中,将FFN替换成基于门控的线性单元,发现效果会变好。因此,我们先简单了解一下门控单元GLU的计算,如下左图:

具体计算:

也就是将输入X分别经过放射变换(线性映射+激活函数)得到U,VU,VU,V。然后再将U,VU,VU,V进行点积,最后再进行线性映射,得到门控线性单元的输出。

上述的GLU中没有对token两两进行注意力计算,如果在上面的U,VU,VU,V中引入注意力,那岂不是就省了前面的多头注意力计算了。如下式:

U,VU,VU,V进行点积计算的时候,如果给VVV乘一个注意力矩阵AAA(维度为nxn,其中n为序列的长度),那岂不是就引入了注意力信息。

基于此,本文就提出了一种新的结构GAU。主要是给出了一种注意力矩阵A的计算方法。具体计算如下:

对输入X进行放射变换(线性映射+激活函数)得到Z,然后对Z分别进行Q,K\mathcal{Q},\mathcal{K}Q,K变换,就是对Z中的每一个标量进行平移等运算。 这里的Q,K\mathcal{Q},\mathcal{K}Q,K变化类似于LayerNorm中的α,β\alpha, \betaα,β,是可训练的。 然后将两种变换的结果进行矩阵乘。最后再经过一个relu2relu^2relu2的激活函数(relu2relu^2relu2是将relu的计算结果平方),得到最终的注意力矩阵。然后和门控注意力单元中的V进行相乘,这样就在门控注意力单元中引入了注意力信息。具体结果如下图所示:

注意:可能是GLU对attention的依赖没有那么强,因此,作者在实验中只用了一个注意力头。

在GAU的实验中,作者固定e=2d,那"n层Attention+n层FFN"的标准Transformer模型,对应的就是"2n层GAU"的新模型,即该模型为FLASH-Quad。其中Quad表明复杂度依然是二次的。即:FLASH的二次复杂度版本。

Fast Linear Attention with GAU

可能有读者发现,在上述的GAU中,你只是将attention和FFN合并起来,替代了标准的attention+FNN,并没有解决attention的二次复杂度呀?对,因此,作者提出了一种快速计算注意力的方法。

过去,在解决注意力的二次复杂度问题上,有两种主流方法: (1)将注意力计算稀疏化、(2)将注意力计算线性化。稀疏化即人为根据先验知识规定哪些token可以进行注意力计算(典型代表: Longformer、BigBird等)。线性化则是提出另外的方法,去逼近标准注意力的效果(典型代表: Linformer、Performer等),如下公式所示:

正常的注意力是将Q,KQ,KQ,K进行矩阵乘,接着经过softmax,最后乘V。如果将K和V先进行乘,则可以大大减少计算量。假设Q,K,VQ,K,VQ,K,V的维度为:(m,d)(m,d)(m,d),则标准注意力的计算量为:
m∗d∗m+m∗d∗mm*d*m+m*d*mmdm+mdm,即: 2dm22dm^22dm2,是跟序列长度m成平方正比。如果先算K乘V,则计算量为:d∗m∗d+d∗m∗dd*m*d+d*m*ddmd+dmd,即:2md22md^22md2,是跟序列长度m成一次正比。所以,第二种方法随着序列的边长,效率会远高于第一种方法。

本文则是根据上述两种方式,结合"稀疏化"和"线性化"的优点,提出了一种"局部+全局"的分块混合的注意力计算方法。

首先是分块注意力的计算,假设序列长度为n,每个块的维度为c,则可分成n/c个块(默认可整除)。Ug,Vg∈Rc×e,Zg∈Rc×s\boldsymbol{U}_g,\boldsymbol{V}_g\in\mathbb{R}^{c\times e},\boldsymbol{Z}_g\in\mathbb{R}^{c\times s}Ug,VgRc×e,ZgRc×s,其中g指的是第g个块。将Zg\boldsymbol{Z}_gZg通过四个放射变换(线性映射+激活)分别得到Qgquad,Kgquad,Qglin,Kglin\boldsymbol{Q}_g^{\text{quad}},\boldsymbol{K}_g^{\text{quad}},\boldsymbol{Q}_g^{\text{lin}},\boldsymbol{K}_g^{\text{lin}}Qgquad,Kgquad,Qglin,Kglin。则块内注意力计算如下:

可以看出上述公式很好理解,不做过多描述。接下来算算其复杂度。每个块内注意力计算复杂度为c2c^2c2,有n/c个块,则块注意力计算整体的复杂度为:(n/c)∗c2(n/c) * c^2(n/c)c2,即nc,也就是正比于n。

接着用Qglin,Kglin\boldsymbol{Q}_g^{\text{lin}},\boldsymbol{K}_g^{\text{lin}}Qglin,Kglin进行全局的attention计算。这里采用的是上述介绍的注意力线性化方法。计算如下:


上述(7)式就是全局的两两计算。(8)式主要是在生成任务中,对标的是带有mask的多头注意力。

最后将两种attention结果整合到GAU中,得到线性版的GAU网络,计算如下:

作者在论文中还贴出来注意力计算的代码,如下所示:

至此,论文就介绍完了,下面简单看一下实验结果。

实验

首先,作者对比了GAU、多头注意力+FFN、以及多头注意力+GLU三种结构,在自回归任务和MLM任务上的表现,如下图:

横轴为速度,纵轴为效果,越靠右上,效果越好。上述实验是在长度为512上的效果对比。可以发现,GAU的在相同效果的前提下,速度更快;在相同速度的前提下,效果更好。

从上表中可以看出,虽然FLASH-Quad也是二次复杂度,但是也比标准的Transformer效果好,速度也更快。另外,随着序列的逐渐变长,FLASH的速度远远快于标准的Transformer。

最后看看消融实验,如下图所示:

上面MF-TFM++模型采用的是多头注意力+FFN的结构,只是多头注意力采用的线性注意力+块注意力。也就是本文提出的注意力计算。从消融实验中,可以看出一个很有用的信息,即localOnly attention比GlobalOnly attention更重要。

本文参考:
[1] 论文: https://arxiv.org/abs/2202.10447
[2] GLU: https://arxiv.org/abs/2002.05202
[3] 苏剑林. (Feb. 25, 2022). 《FLASH:可能是近来最有意思的高效Transformer设计 》[Blog post]. Retrieved from https://kexue.fm/archives/8934

### 2025年最新的高效Transformer网络模型及其架构 随着研究的发展,到2025年,几种新的Transformer变体已经显著提高了效率和性能。这些改进主要集中在优化计算资源利用、减少参数量以及增强模型泛化能力上。 #### Perceiver IO 架构 Perceiver IO 是一种扩展了原始Perceiver结构的模型,在处理大规模输入数据方面表现出色[^1]。该模型通过引入交叉注意力机制来连接编码器与解码器部分,从而允许更灵活的信息流动模式。这种设计使得即使面对非常大的图像或其他高维传感器读数时也能保持良好的表现力。 ```python class PerceiverIO(nn.Module): def __init__(self, input_channels=3, depth=8, num_latents=256, latent_dim=512): super().__init__() self.latent_layer = nn.Parameter(torch.randn(num_latents, latent_dim)) def forward(self, x): b, c, h, w = x.shape latents = repeat(self.latent_layer, 'n d -> b n d', b=b) for _ in range(depth): cross_attn_to_x = CrossAttention(dim=c) attn_out = cross_attn_to_x(latents, context=x.flatten(2).permute(0, 2, 1)) # More layers... ``` #### Flash Attention 技术 Flash Attention 提出了一个新的视角来看待自注意层中的softmax操作,并提出了近似方法以加速其执行速度并降低内存占用率。这种方法特别适用于长序列上的任务,因为它能够在几乎不损失精度的情况下极大地提高推理阶段的速度。 #### Efficient Global Context (EGC) Modules Efficient Global Context modules 被用来替代传统的全连接层或卷积核大小较大的组件。它能够有效地捕捉全局上下文信息而不增加过多额外负担给整个框架。这有助于构建更加紧凑但依然具备强大表达能力的神经网络体系。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

传道解惑也

打赏一下咯

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值