多头注意力机制MHSA
时间: 2025-05-12 15:41:20 浏览: 32
### 多头注意力机制(Multi-Head Self-Attention Mechanism)
多头自注意力机制是一种用于处理序列数据的核心技术,广泛应用于自然语言处理领域以及计算机视觉中的模型架构中[^1]。它通过并行计算多个独立的注意力机制来捕获输入序列的不同子空间特征。
#### 工作原理
多头注意力机制基于标准的注意力机制扩展而来。其核心思想是将输入向量映射到不同的线性变换空间中,在这些空间中分别执行注意力操作后再拼接起来形成最终输出。具体来说:
1. 输入张量被分为三个部分:查询(Query)、键(Key)和值(Value)。这三个矩阵由原始输入经过不同权重矩阵相乘得到。
\[
Q = XW_Q,\ K = XW_K,\ V = XW_V
\]
这里 \(X\) 是输入张量,\(W_Q\), \(W_K\), 和 \(W_V\) 则分别是对应的可学习参数矩阵。
2. 对于每一个头部 (head),单独计算注意力分数并通过 softmax 函数规范化得分:
\[
Attention(Q,K,V) = softmax(\frac{QK^T}{\sqrt{d_k}})V
\]
其中 \(d_k\) 表示 Key 的维度大小,用来缩放点积防止梯度消失或爆炸问题。
3. 将所有头的结果连接在一起,并再次应用一个投影层以获得最后的输出表示形式。
以下是 Python 实现的一个简单例子:
```python
import torch
import torch.nn as nn
class MultiHeadSelfAttention(nn.Module):
def __init__(self, d_model, num_heads):
super(MultiHeadSelfAttention, self).__init__()
assert d_model % num_heads == 0
self.d_model = d_model
self.num_heads = num_heads
self.depth = d_model // num_heads
self.wq = nn.Linear(d_model, d_model)
self.wk = nn.Linear(d_model, d_model)
self.wv = nn.Linear(d_model, d_model)
self.dense = nn.Linear(d_model, d_model)
def split_heads(self, x, batch_size):
"""Split the last dimension into (num_heads, depth)."""
x = x.view(batch_size, -1, self.num_heads, self.depth)
return x.permute(0, 2, 1, 3) # (batch_size, num_heads, seq_len, depth)
def forward(self, v, k, q, mask=None):
batch_size = q.size(0)
q = self.wq(q)
k = self.wk(k)
v = self.wv(v)
q = self.split_heads(q, batch_size)
k = self.split_heads(k, batch_size)
v = self.split_heads(v, batch_size)
scaled_attention, attention_weights = self.scaled_dot_product_attention(
q, k, v, mask=mask)
scaled_attention = scaled_attention.permute(0, 2, 1, 3).contiguous()
concat_attention = scaled_attention.view(batch_size, -1, self.d_model)
output = self.dense(concat_attention)
return output, attention_weights
@staticmethod
def scaled_dot_product_attention(q, k, v, mask=None):
matmul_qk = torch.matmul(q, k.transpose(-2, -1))
dk = k.shape[-1]
scaled_attention_logits = matmul_qk / torch.sqrt(torch.tensor(dk, dtype=torch.float32))
if mask is not None:
scaled_attention_logits += (mask * -1e9)
attention_weights = torch.softmax(scaled_attention_logits, dim=-1)
outputs = torch.matmul(attention_weights, v)
return outputs, attention_weights
```
上述代码定义了一个 `MultiHeadSelfAttention` 类,其中包含了分割头的操作、点积注意力函数及其后续处理过程。
阅读全文
相关推荐


















