本文详细解析了多头自注意力机制的两种实现方式:堆叠单头注意力和通过权重分割的高效实现。文章介绍了多头注意力将输入划分为多个独立"头"并行处理,然后合并结果的原理,并提供了完整的PyTorch代码。通过张量重塑和转置操作实现并行计算,显著提高了效率。这种机制是Transformer架构的核心组件,对大语言模型识别复杂模式至关重要。
堆叠多层单头注意力层
在实际应用中,实现多头注意力需要创建多个自注意力机制的实例,每个实例都具有独立的权重,然后将它们的输出合并。多个自注意力机制实例的应用属于计算密集型(CPU
密集型)操作,但它对于识别复杂模式至关重要,这是基于 Transformer
的大语言模型所擅长的能力之一。
上图中的输入 inputs
是一个 6 * 3
维的向量(继承之前的文章,这里可以是任意纬度的)。然后是一个单头注意力模块,他创建了一个 6 * 2
维的上下文向量, 上下文向量中仍然有 6 个 token
,但是现在每一行是二维的嵌入。这样我们就得到了张量 Z
(这里被称为 )。
接下来我们要创建第二个上下文向量 Z
(这里被称为 )。之前我们只有一个输出,现在我们需要两个,这就是多头情况。 然后我们将这两个上下文张量连接起来,我们就有了一个四个维度的输出。
我们实现上述过程的方式就是将两个因果自注意力机制的实现堆叠起来。
上图展示了多头注意力模块的结构,该模块由多个单头注意力模块组成,彼此堆叠在一起。在因果自注意力机制中包含 query
,key
,value
三个矩阵。在多头注意力机制中有多个权重矩阵 。
可以简单地将其视为几乎执行两次相同的操作,然后将其连接起来。所以最后,我们还有两个上下文张量,如前所述,最后我们创建了这个组合上下文张量。
在代码中,我们可以通过实现一个简单的 MultiHeadAttentionWrapper
类来实现这一点,该类会堆叠多个我们之前实现的 CausalAttention 模块的实例:
class MultiHeadAttentionWrapper(nn.Module):
def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
super().__init__()
self.heads = nn.ModuleList(
[CausalAttention(d_in, d_out, context_length, dropout, qkv_bias)
for _ in range(num_heads)]
)
def forward(self, x):
return torch.cat([head(x) for head in self.heads], dim=-1)
torch.manual_seed(123)
context_length = batch.shape[1] # This is the number of tokens
d_in, d_out = 3, 2
mha = MultiHeadAttentionWrapper(
d_in, d_out, context_length, 0.0, num_heads=2
)
context_vecs = mha(batch)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)
如果我们使用这个 MultiHeadAttentionWrapper 类,并通过设置 num_heads=2 使用两个注意力头,同时将 CausalAttention 的输出维度 d_out 设置为 2,那么生成的上下文向量将是 4 维的(d_out*num_heads= 4)输出结果如下:
由此生成的 context_vecs
张量的第一个维度是 2
,因为我们有两个输入文本(输入文本被复制,因此它们的上下文向量完全相同)。第二个维度对应每个输入中的 6 个 token。第三个维度对应每个 token 的 4 维嵌入向量。
我们实现了一个 MultiHeadAttentionWrapper
,用于组合多个单头注意力模块。不过需要注意的是,在 forward
方法中,这些模块是通过 [head(x) for head in self.heads]
串行处理的。
我们可以通过并行处理各个注意力头来优化该实现。实现这一目标的一种方法是,通过矩阵乘法同时计算所有注意力头的输出,我们将在下一节中详细探讨。
通过权重分割实现多头注意力机制
在前一节中,我们创建了一个 MultiHeadAttentionWrapper,通过堆叠多个单头注意力模块来实现多头注意力。这是通过实例化并组合多个 CausalAttention 对象实现的。
与其维护两个独立的类 MultiHeadAttentionWrapper 和 CausalAttention,我们可以将这两个概念合并为一个 MultiHeadAttention 类。此外,除了简单地合并 MultiHeadAttentionWrapper 和 CausalAttention 的代码外,我们还会进行一些额外的修改,以更高效地实现多头注意力机制。
在 MultiHeadAttentionWrapper 中,多头机制是通过创建一个包含多个 CausalAttention 对象的列表(self.heads)来实现的,每个对象代表一个独立的注意力头。CausalAttention 类独立执行注意力机制,每个头的结果最终被拼接起来。相比之下,接下来的 MultiHeadAttention 类则将多头功能集成在一个单一的类中。它通过对变换后的query、key和value张量进行重塑,将输入分割成多个头,并在计算注意力后将这些头的结果组合在一起。
classMultiHeadAttention(nn.Module):
def__init__(self, d_in, d_out,
context_length, dropout, num_heads, qkv_bias=False):
super().__init__()
assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
self.d_out = d_out
self.num_heads = num_heads
self.head_dim = d_out // num_heads #A
self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
self.out_proj = nn.Linear(d_out, d_out) #B
self.dropout = nn.Dropout(dropout)
self.register_buffer(
'mask',
torch.triu(torch.ones(context_length, context_length), diagonal=1)
)
defforward(self, x):
b, num_tokens, d_in = x.shape
keys = self.W_key(x) #C
queries = self.W_query(x) #C
values = self.W_value(x) #C
keys = keys.view(b, num_tokens, self.num_heads, self.head_dim) #D
values = values.view(b, num_tokens, self.num_heads, self.head_dim) #D
queries = queries.view(b, num_tokens, self.num_heads, self.head_dim) #D
keys = keys.transpose(1, 2) #E
queries = queries.transpose(1, 2) #E
values = values.transpose(1, 2) #E
attn_scores = queries @ keys.transpose(2, 3) #F
mask_bool = self.mask.bool()[:num_tokens, :num_tokens] #G
attn_scores.masked_fill_(mask_bool, -torch.inf) #H
attn_weights = torch.softmax(
attn_scores / keys.shape[-1]**0.5, dim=-1)
attn_weights = self.dropout(attn_weights)
context_vec = (attn_weights @ values).transpose(1, 2) #I
context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out) #J
context_vec = self.out_proj(context_vec) #K
return context_vec
#A 将投影维度缩小,以匹配期望的输出维度
#B 使用线性层组合头部输出
#C 张量形状:(b, num_tokens, d_out)
#D 我们通过添加 num_heads 维度来隐式地拆分矩阵。然后展开最后一个维度,使其形状从 (b, num_tokens, d_out) 转换为 (b, num_tokens, num_heads, head_dim)
#E 将张量的形状从 (b, num_tokens, num_heads, head_dim) 转置为 (b, num_heads, num_tokens, head_dim)
#F 对每个注意力头进行点积运算
#G 掩码被截断到 token 的数量
#H 使用掩码填充注意力分数
#I 张量形状:(b, num_tokens, n_heads, head_dim)
#J 将多个注意力头的输出结果合并,其中输出维度 self.d_out 等于注意力头数 self.num_heads 与每个头的维度 self.head_dim 的乘积
#K 添加一个可选的线性投影层
尽管 MultiHeadAttention 类中张量的重塑(.view)和转置(.transpose)操作看起来非常复杂,但从数学角度来看,MultiHeadAttention 类与之前的 MultiHeadAttentionWrapper 类实现的概念是相同的。
从宏观层面上看,在之前的 MultiHeadAttentionWrapper 中,我们通过堆叠多个单头注意力层的方式来组合成一个多头注意力层。而 MultiHeadAttention 类采用了一种集成的方法:它从一个多头注意力层开始,并在内部将该层分解为各个独立的注意力头,如图所示。
如图所示,query、key 和 value 张量的拆分是通过张量的重塑和转置操作实现的,这些操作分别使用了 PyTorch 的 .view 和 .transpose 方法。
首先,通过线性层对输入进行投影(分别生成 query、key 和 value),然后将其重塑为多个注意力头的形式。
关键操作是将 d_out 维度拆分成 num_heads 和 head_dim,其中 head_dim = d_out / num_heads。这种拆分通过 .view 方法实现:将形状为 (b, num_tokens, d_out) 的张量重塑为 (b, num_tokens, num_heads, head_dim)。
接下来对张量进行转置操作,将 num_heads 维度移动到 num_tokens 维度之前,使其形状变为 (b, num_heads, num_tokens, head_dim)。这种转置对于在不同注意力头之间正确对齐查询(queries)、键(keys)和值(values),并高效执行批量矩阵乘法至关重要。
为了说明这种批量矩阵乘法,假设我们有如下示例张量:
a = torch.tensor([[[[0.2745, 0.6584, 0.2775, 0.8573], #A
[0.8993, 0.0390, 0.9268, 0.7388],
[0.7179, 0.7058, 0.9156, 0.4340]],
[[0.0772, 0.3565, 0.1479, 0.5331],
[0.4066, 0.2318, 0.4545, 0.9737],
[0.4606, 0.5159, 0.4220, 0.5786]]]])
#A 该张量的形状为 (b, num_heads, num_tokens, head_dim) = (1, 2, 3, 4)
接下来,我们在张量本身与张量的一个视图之间执行批量矩阵乘法操作,其中张量的视图将最后两个维度(num_tokens 和 head_dim)进行了转置:
print(a @ a.transpose(2, 3))
在这种情况下,PyTorch 中的矩阵乘法实现能够处理四维输入张量,因此矩阵乘法会在输入张量的最后两个维度(即 num_tokens 和 head_dim)之间执行,并对每个注意力头重复该操作。 上述方法成为了一种更简洁的方式,可以单独计算每个头的矩阵乘法:
irst_head = a[0, 0, :, :]
first_res = first_head @ first_head.T
print("First head:\n", first_res)
second_head = a[0, 1, :, :]
second_res = second_head @ second_head.T
print("\nSecond head:\n", second_res)
该结果与我们之前使用批量矩阵乘法 print(a @ a.transpose(2, 3)) 时获得的结果完全相同:
在多头注意力机制中,计算完注意力权重和上下文向量之后,将所有头的上下文向量转置回形状 (b, num_tokens, num_heads, head_dim)。然后将这些向量重新塑形(展平)为 (b, num_tokens, d_out) 的形状,从而有效地将所有头的输出组合在一起。
此外,我们在多头注意力机制中添加了一个称为输出投影层(self.out_proj)的模块,用于在组合多个头的输出后进行投影。而在因果注意力类中并没有这个投影层。这个输出投影层并非绝对必要,但由于它在许多 LLM 架构中被广泛使用,因此我们在这里加上以保持完整性。
尽管 MultiHeadAttention 类由于额外的张量重塑和转置操作看起来比 MultiHeadAttentionWrapper 更复杂,但它更加高效。原因在于,我们只需执行一次矩阵乘法即可计算键(keys),对于查询(queries)和值(values)也是如此。而在 MultiHeadAttentionWrapper 中,我们需要对每个注意力头重复执行这一矩阵乘法操作,这种计算方式的开销非常大。
MultiHeadAttention 类的用法与我们之前实现的 SelfAttention 和 CausalAttention 类类似:
torch.manual_seed(123)
batch_size, context_length, d_in = batch.shape
d_out = 2
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)
context_vecs = mha(batch)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)
从结果可以看出,输出维度是由d_out参数直接控制的:
tensor([[[0.3190, 0.4858],
[0.2943, 0.3897],
[0.2856, 0.3593],
[0.2693, 0.3873],
[0.2639, 0.3928],
[0.2575, 0.4028]],
[[0.3190, 0.4858],
[0.2943, 0.3897],
[0.2856, 0.3593],
[0.2693, 0.3873],
[0.2639, 0.3928],
[0.2575, 0.4028]]], grad_fn=<ViewBackward0>)
context_vecs.shape: torch.Size([2, 6, 2])
在本文中,我们实现了 MultiHeadAttention
类,这将在后续实现和训练 LLM
时使用。请注意,虽然代码功能齐全,但我们使用了较小的嵌入维度和注意力头数,以便让输出结果更易于阅读。
作为对比,最小的 GPT-2 模型(1.17 亿参数)具有 12 个注意力头和 768 的上下文向量嵌入大小。而最大的 GPT-2 模型(15 亿参数)则具有 25 个注意力头和 1600 的上下文向量嵌入大小。请注意,在 GPT 模型中,token 输入的嵌入大小与上下文嵌入大小是相同的(d_in = d_out
)。
读者福利大放送:如果你对大模型感兴趣,想更加深入的学习大模型**,那么这份精心整理的大模型学习资料,绝对能帮你少走弯路、快速入门**
如果你是零基础小白,别担心——大模型入门真的没那么难,你完全可以学得会!
👉 不用你懂任何算法和数学知识,公式推导、复杂原理这些都不用操心;
👉 也不挑电脑配置,普通家用电脑完全能 hold 住,不用额外花钱升级设备;
👉 更不用你提前学 Python 之类的编程语言,零基础照样能上手。
你要做的特别简单:跟着我的讲解走,照着教程里的步骤一步步操作就行。
包括:大模型学习线路汇总、学习阶段,大模型实战案例,大模型学习视频,人工智能、机器学习、大模型书籍PDF。带你从零基础系统性的学好大模型!
现在这份资料免费分享给大家,有需要的小伙伴,直接VX扫描下方二维码就能领取啦😝↓↓↓
为什么要学习大模型?
数据显示,2023 年我国大模型相关人才缺口已突破百万,这一数字直接暴露了人才培养体系的严重滞后与供给不足。而随着人工智能技术的飞速迭代,产业对专业人才的需求将呈爆发式增长,据预测,到 2025 年这一缺口将急剧扩大至 400 万!!
大模型学习路线汇总
整体的学习路线分成L1到L4四个阶段,一步步带你从入门到进阶,从理论到实战,跟着学习路线一步步打卡,小白也能轻松学会!
大模型实战项目&配套源码
光学理论可不够,这套学习资料还包含了丰富的实战案例,让你在实战中检验成果巩固所学知识
大模型学习必看书籍PDF
我精选了一系列大模型技术的书籍和学习文档(电子版),它们由领域内的顶尖专家撰写,内容全面、深入、详尽,为你学习大模型提供坚实的理论基础。
大模型超全面试题汇总
在面试过程中可能遇到的问题,我都给大家汇总好了,能让你们在面试中游刃有余
这些资料真的有用吗?
这份资料由我和鲁为民博士(北京清华大学学士和美国加州理工学院博士)共同整理,现任上海殷泊信息科技CEO,其创立的MoPaaS云平台获Forrester全球’强劲表现者’认证,服务航天科工、国家电网等1000+企业,以第一作者在IEEE Transactions发表论文50+篇,获NASA JPL火星探测系统强化学习专利等35项中美专利。本套AI大模型课程由清华大学-加州理工双料博士、吴文俊人工智能奖得主鲁为民教授领衔研发。
资料内容涵盖了从入门到进阶的各类视频教程和实战项目,无论你是小白还是有些技术基础的技术人员,这份资料都绝对能帮助你提升薪资待遇,转行大模型岗位。
👉获取方式:
😝有需要的小伙伴,可以保存图片到VX扫描下方二维码免费领取【保证100%免费】
相信我,这套大模型系统教程将会是全网最齐全 最适合零基础的!!