多头注意力机制注意力机制
时间: 2025-02-09 09:11:23 浏览: 44
### 多头注意力机制概述
多头注意力机制是一种用于增强神经网络处理序列数据能力的技术,在自然语言处理和其他领域取得了显著成功。通过并行运行多个不同的注意力函数,该方法能够捕捉输入序列中的复杂关系。
#### 工作原理
在标准的单头注意力机制中,查询(Query)、键(Key)和值(Value)向量被用来计算加权求和表示[^2]。然而,这种单一视角可能不足以捕获所有重要的特征模式。因此引入了多头设计:
- **多视图关注**:创建 h 组独立的 QKV 变换参数矩阵 W_Q^i,W_K^i 和 W_V^i (i=1,...,h),从而形成 h 个平行工作的 “head”。每个 head 学习到的数据表征角度略有差异。
- **线性变换与缩放点积注意**:对于每一个 head i ,执行如下操作:
```python
Zi = softmax(Qi * Ki.T / sqrt(dk)) * Vi
```
- **拼接输出**:将各个 head 的输出按顺序连接起来,并乘上一个额外的学习权重矩阵 Wo 来获得最终的结果 Z。
```python
import torch.nn as nn
import math
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads):
super(MultiHeadAttention, self).__init__()
assert d_model % num_heads == 0
self.d_k = d_model // num_heads
self.num_heads = num_heads
self.linears = clones(nn.Linear(d_model, d_model), 4)
def forward(self, query, key, value, mask=None):
batch_size = query.size(0)
# Do all the linear projections in one go
query, key, value = [
lin(x).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
for lin, x in zip(self.linears, (query, key, value))
]
# Apply attention on all projected vectors in batch.
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
p_attn = F.softmax(scores, dim=-1)
output = torch.matmul(p_attn, value)
# Concatenate and apply final linear transformation.
output = (
output.transpose(1, 2)
.contiguous()
.view(batch_size, -1, self.num_heads * self.d_k)
)
return self.linears[-1](output)
```
此实现展示了如何构建一个多头自注意力层,其中包含了必要的张量维度转换以及掩码处理逻辑。
阅读全文
相关推荐


















