def mha(x, attn, n_head): # [n_seq, n_embd] -> [n_seq, n_embd] """ Task: Complete the code of the multi-head attention Input: x: Tensor attn: dictionary that load from gpt2 weight. c_attn and c_proj are the params of two linear layer n_head: number of head Output: Tensorying multi-head attention and linear transformation, shape [n_seq, n_embd]. """ c_attn, c_proj = attn['c_attn'], attn['c_proj'] # qkv projection x = linear(x, c_attn) # [n_seq, n_embd] -> [n_seq, 3*n_embd] # Split into qkv """ Task: Split the q,k,v matrix from the tensor x Notes: [n_seq, 3*n_embd] -> 3 * [n_seq, n_embd] """ torch.split(x,) qkv = None # need to modify # Split into heads qkv_heads = [qkv_part.chunk(n_head, dim=-1) for qkv_part in qkv] # 3 * [n_seq, n_embd] -> 3 * n_head * [n_seq, n_embd/n_head] qkv_heads = list(zip(*qkv_heads)) # [3, n_head, n_seq, n_embd/n_head] # Causal mask to hide future inputs from being attended to """ Task: Construct mask matrix Notes: | 0 -inf -inf ... -inf | | 0 0 -inf ... -inf | | 0 0 0 ... -inf | |... ... ... ... ... | | 0 0 0 ... 0 | Mask is a tensor whose dimension is [n_seq, n_seq] """ causal_mask = None # need to modify # Perform attention over each head out_heads = [attention(q, k, v, causal_mask) for q, k, v in qkv_heads] # n_head * [n_seq, n_embd/n_head] # Merge heads """ Task: merge multi-heads results Notes: n_head * [n_seq, n_embd/n_head] --> [n_seq, n_embd] """ x = None # need to modify # Out projection x = linear(x, c_proj) # [n_seq, n_embd] -> [n_seq, n_embd] return x 注释解释一下
时间: 2025-06-25 21:27:19 浏览: 10
<think>好的,用户想要为GPT-2风格的多头注意力机制添加详细的逐行注释,并解释各部分的功能和关键概念。我需要先回忆多头注意力的结构,以及GPT-2的实现特点。根据提供的引用,GPT-2使用了Transformer的decoder部分,所以可能需要考虑掩码机制。
首先,多头注意力是将输入分成多个头,每个头进行独立的注意力计算,最后合并结果。代码中应该涉及线性变换、分割张量、计算注意力分数、应用softmax,以及最后的线性投影。
用户提供的参考引用3中提到了注意力分数的计算方式,包括缩放点积和softmax应用,这需要体现在注释中。同时,代码可能涉及掩码处理,因为GPT-2是自回归模型,需要防止未来信息泄露。
我需要确保注释覆盖每个关键步骤,比如线性层的用途、头的分割、注意力计算、掩码应用、合并头等。还要解释参数如d_model, num_heads, d_k的作用,以及为什么需要缩放因子sqrt(d_k)。
另外,要注意代码中的变量命名是否与标准实现一致,比如query, key, value的线性变换,以及最后的output投影层。还要检查是否有残差连接和层归一化,但根据问题描述,可能用户只关注注意力函数本身。
最后,需要确认代码示例的正确性,比如张量形状的变化,尤其是batch_size和序列长度的处理,以及多头分割后的维度调整。确保注释清晰,能让读者理解每一步的目的和背后的原理。</think>以下是为GPT-2风格的多头注意力函数添加的逐行注释实现,包含关键概念解释和代码结构分析:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
assert d_model % num_heads == 0 # 确保模型维度能被头数整除
self.d_model = d_model # 总特征维度(如768)
self.num_heads = num_heads # 注意力头数(如12)
self.d_k = d_model // num_heads # 每个头的特征维度(如64)
# 定义Q/K/V的线性变换层(不包含偏置项)
self.q_linear = nn.Linear(d_model, d_model, bias=False)
self.k_linear = nn.Linear(d_model, d_model, bias=False)
self.v_linear = nn.Linear(d_model, d_model, bias=False)
# 最终输出线性变换层
self.output_linear = nn.Linear(d_model, d_model, bias=False)
def forward(self, query, key, value, mask=None):
"""
:param query: [batch_size, seq_len, d_model]
:param key: [batch_size, seq_len, d_model]
:param value: [batch_size, seq_len, d_model]
:param mask: [batch_size, 1, 1, seq_len](用于decoder的掩码)
:return: [batch_size, seq_len, d_model]
"""
batch_size = query.size(0)
# 步骤1:线性变换得到Q/K/V
Q = self.q_linear(query) # [batch_size, seq_len, d_model]
K = self.k_linear(key) # 计算键的线性投影[^3]
V = self.v_linear(value) # 计算值的线性投影
# 步骤2:分割多头(reshape + transpose)
# 将特征维度分割为num_heads个d_k维度
Q = Q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1,2) # [batch, heads, seq_len, d_k]
K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1,2)
V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1,2)
# 步骤3:计算缩放点积注意力
# 公式:Attention(Q,K,V)=softmax(QK^T/sqrt(d_k))V
scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.d_k, dtype=torch.float32))
# 步骤4:应用掩码(仅decoder需要)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf')) # 将掩码位置设为负无穷
# 步骤5:应用softmax得到注意力权重
attention_weights = F.softmax(scores, dim=-1) # [batch, heads, seq_len, seq_len]
# 步骤6:加权求和得到注意力输出
attention_output = torch.matmul(attention_weights, V) # [batch, heads, seq_len, d_k]
# 步骤7:合并多头(transpose + reshape)
attention_output = attention_output.transpose(1,2).contiguous() # [batch, seq_len, heads, d_k]
attention_output = attention_output.view(batch_size, -1, self.d_model) # 合并回原始维度
# 步骤8:最终线性变换
output = self.output_linear(attention_output) # [batch_size, seq_len, d_model]
return output, attention_weights
```
### 关键概念说明
1. **多头机制**:将输入特征分成多个子空间并行处理,增强模型捕捉不同位置关系的能力
2. **缩放因子**:$\sqrt{d_k}$用于防止点积结果过大导致softmax梯度消失
3. **掩码机制**:在decoder中阻止当前位置关注后续位置,保持自回归特性[^2]
4. **参数共享**:每个注意力头共享相同的线性变换矩阵,通过分割实现不同关注模式
### 性能优化特点
1. 使用矩阵运算并行处理所有注意力头
2. 通过转置和reshape操作避免显式循环
3. 保持输入输出维度一致,便于残差连接[^2]
4. 采用参数化的线性变换实现灵活的特征组合
阅读全文
相关推荐

















