Hand-Coding Multi-Head Attention in Python
### Python Implementation of Multi-Head Attention
Multi-Head Attention (MHA) is a powerful neural-network component that is widely used in the Transformer architecture for natural language processing. It captures different features and patterns of the input by running several independent attention heads in parallel[^1].
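For reference, the standard formulation (following the original Transformer paper) is:

$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V
$$

$$
\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\,W^O,
\qquad
\mathrm{head}_i = \mathrm{Attention}(QW_i^Q,\, KW_i^K,\, VW_i^V)
$$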
Below is an implementation of multi-head attention based on PyTorch:
#### Code Implementation
```python
import torch
import torch.nn as nn
import math


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.depth = d_model // num_heads
        # Linear projections for queries, keys, values, and the final output
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def split_heads(self, x, batch_size):
        """Split the last dimension into (num_heads, depth)."""
        x = x.view(batch_size, -1, self.num_heads, self.depth)
        return x.permute(0, 2, 1, 3)  # (batch_size, num_heads, seq_len, depth)

    def scaled_dot_product_attention(self, q, k, v, mask=None):
        matmul_qk = torch.matmul(q, k.transpose(-2, -1))  # (..., seq_len_q, seq_len_k)
        dk = k.size(-1)
        scaled_attention_logits = matmul_qk / math.sqrt(dk)
        if mask is not None:
            # Positions where mask == 1 get a large negative logit and are suppressed
            scaled_attention_logits += (mask * -1e9)
        attention_weights = torch.softmax(scaled_attention_logits, dim=-1)  # (..., seq_len_q, seq_len_k)
        output = torch.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)
        return output, attention_weights

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)
        Q = self.split_heads(self.W_q(query), batch_size)  # (batch_size, num_heads, seq_len_q, depth)
        K = self.split_heads(self.W_k(key), batch_size)    # (batch_size, num_heads, seq_len_k, depth)
        V = self.split_heads(self.W_v(value), batch_size)  # (batch_size, num_heads, seq_len_v, depth)
        scaled_attention, attention_weights = self.scaled_dot_product_attention(Q, K, V, mask=mask)
        scaled_attention = scaled_attention.permute(0, 2, 1, 3).contiguous()  # (batch_size, seq_len_q, num_heads, depth)
        concat_attention = scaled_attention.view(batch_size, -1, self.d_model)  # (batch_size, seq_len_q, d_model)
        output = self.W_o(concat_attention)  # (batch_size, seq_len_q, d_model)
        return output, attention_weights
```
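A minimal usage sketch follows; the hyperparameters and tensor shapes below are illustrative assumptions, not part of the original code:

```python
import torch

# Hypothetical hyperparameters chosen only for illustration
d_model, num_heads, batch_size, seq_len = 512, 8, 2, 10

mha = MultiHeadAttention(d_model, num_heads)
x = torch.randn(batch_size, seq_len, d_model)  # dummy input used as query, key, and value

output, attn_weights = mha(x, x, x)            # self-attention: query = key = value
print(output.shape)        # torch.Size([2, 10, 512])
print(attn_weights.shape)  # torch.Size([2, 8, 10, 10])
```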
The code above implements the standard multi-head attention mechanism; its core logic covers the following aspects:
1. **Linear projections**: `W_q`, `W_k`, and `W_v` linearly project the query, key, and value vectors.
2. **Head splitting**: the embedding dimension is split across multiple heads so that attention weights can be computed in parallel over different subspaces[^3].
3. **Scaled dot-product attention**: queries are matched against keys, the scores are scaled, and a softmax is applied to obtain the attention weights[^4].
4. **Output projection**: finally, `W_o` recombines the results of all heads into a single representation.
This implementation helps the model learn and generalize better during training, and is particularly effective at capturing long-range dependencies.
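When the module is used in a decoder-style (autoregressive) setting, a causal mask can be passed through the `mask` argument. In this implementation, positions where `mask == 1` are suppressed (a large negative value is added to their logits), so a compatible upper-triangular mask can be built as below. This is a sketch of one assumed usage, reusing `mha` and `x` from the example above, and is not part of the original code:

```python
import torch

seq_len = 10  # hypothetical sequence length, matching the usage sketch above

# Ones above the diagonal mark future positions to hide.
# Shape (1, 1, seq_len, seq_len) broadcasts over (batch_size, num_heads, seq_len, seq_len).
causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).view(1, 1, seq_len, seq_len)

output, attn_weights = mha(x, x, x, mask=causal_mask)
```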