Multi-Head Attention Code
### Code Implementation of Multi-Head Attention
The following is a PyTorch implementation of multi-head self-attention (Multi-Head Self-Attention):
```python
import torch
import torch.nn as nn
import math
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(MultiHeadSelfAttention, self).__init__()
        assert embed_size % heads == 0, "Embedding size is not divisible by number of heads."
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        # Linear layers for queries, keys, and values
        self.W_q = nn.Linear(embed_size, embed_size)
        self.W_k = nn.Linear(embed_size, embed_size)
        self.W_v = nn.Linear(embed_size, embed_size)

        # Final linear layer after concatenating all the heads' outputs
        self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, query, key, value):
        N = query.shape[0]  # Batch size
        seq_len_query = query.shape[1]
        seq_len_key = key.shape[1]

        # Pass through the respective linear layers
        Q = self.W_q(query)
        K = self.W_k(key)
        V = self.W_v(value)

        # Split the embedding into `heads` parts
        Q = Q.view(N, seq_len_query, self.heads, self.head_dim).permute(0, 2, 1, 3)  # (N, heads, seq_len_query, head_dim)
        K = K.view(N, seq_len_key, self.heads, self.head_dim).permute(0, 2, 1, 3)    # (N, heads, seq_len_key, head_dim)
        V = V.view(N, seq_len_key, self.heads, self.head_dim).permute(0, 2, 1, 3)    # (N, heads, seq_len_key, head_dim)

        # Compute attention scores using the scaled dot-product formula
        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / math.sqrt(self.head_dim)  # (N, heads, seq_len_query, seq_len_key)

        # Apply softmax to get attention weights
        attention_scores = torch.softmax(energy, dim=-1)  # (N, heads, seq_len_query, seq_len_key)

        # Multiply with the value vectors
        out = torch.matmul(attention_scores, V)  # (N, heads, seq_len_query, head_dim)

        # Concatenate all heads back together
        out = out.permute(0, 2, 1, 3).contiguous().view(N, seq_len_query, self.embed_size)  # (N, seq_len_query, embed_size)

        # Pass the concatenated output through the final fully connected layer
        out = self.fc_out(out)  # (N, seq_len_query, embed_size)
        return out
```
The code above implements the core of multi-head self-attention, including the linear projections of the query (Q), key (K), and value (V) vectors and the scaled dot-product computation[^2].
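As a quick sanity check, here is a minimal usage sketch; the concrete sizes (embed_size=256, heads=8, a batch of 2 sequences of length 10) are illustrative assumptions rather than values from the text above:
```python
# Minimal usage sketch; the sizes below are illustrative assumptions.
import torch

attn = MultiHeadSelfAttention(embed_size=256, heads=8)
x = torch.randn(2, 10, 256)   # (batch, seq_len, embed_size)
out = attn(x, x, x)           # self-attention: query = key = value
print(out.shape)              # torch.Size([2, 10, 256])
```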
#### Key Points
1. **Splitting the embedding dimension**: To support multiple heads, the input embedding is divided evenly among the heads. This is done purely with reshape/permute operations.
2. **Scaled dot-product attention**: When computing the energy scores, the result is divided by the scaling factor \( \sqrt{\text{head\_dim}} \) to stabilize gradient propagation and improve numerical stability (see the formula right after this list).
3. **Softmax application**: Applying softmax over the last dimension yields normalized attention weights.
4. **Output concatenation and projection**: The per-head results are rearranged back into the original shape, concatenated, and passed through the final fully connected layer for further processing.
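Points 2 and 3 together amount to the standard scaled dot-product attention computed independently per head, with \( d_k = \text{head\_dim} \):

\[
\text{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V
\]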
For clarity of exposition, this implementation omits the padding mask and sequence (causal) mask logic, so it applies to the unmasked multi-head attention setting[^1].
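For reference, the following is a minimal sketch of how a padding mask is commonly applied in this setup (the tensor shapes and the `mask` handling below are assumptions, not part of the original code): positions marked as padding receive a large negative score before the softmax, so their attention weights become effectively zero.
```python
# Sketch of the omitted mask logic (shapes and names are assumptions):
# keys flagged as padding get a score of -1e20 before softmax,
# so their attention weights end up close to zero.
import math
import torch

N, heads, seq_len, head_dim = 2, 8, 5, 32
Q = torch.randn(N, heads, seq_len, head_dim)
K = torch.randn(N, heads, seq_len, head_dim)

# Example padding mask: the last two positions of every sequence are padding.
mask = torch.ones(N, 1, 1, seq_len, dtype=torch.bool)
mask[..., -2:] = False

energy = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(head_dim)
energy = energy.masked_fill(~mask, float("-1e20"))  # block padded keys
attention_weights = torch.softmax(energy, dim=-1)   # padded keys get ~0 weight
```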