mhsa注意力机制
时间: 2025-03-23 19:06:44 浏览: 102
### MHSA 的工作原理
多头自注意力机制(Multi-Head Self-Attention, MHSA)是一种扩展的自注意力机制,旨在解决单一自注意力机制可能存在的局限性。通过引入多个并行的自注意力计算路径,MHSA 能够捕获更丰富的上下文信息和特征表示。
#### 自注意力机制的基础
自注意力机制的核心在于通过对输入序列的不同位置进行加权求和来生成新的表示向量。具体来说,它依赖于三个主要组件:查询(Query)、键(Key)和值(Value)。这些组件由输入嵌入经过线性变换得到:
\[
Q = XW_Q,\ K = XW_K,\ V = XW_V
\]
其中 \(X\) 是输入矩阵,\(W_Q\)、\(W_K\) 和 \(W_V\) 分别是可训练的权重矩阵[^3]。
接着,利用点积操作计算注意力分数,并对其进行 softmax 归一化处理:
\[
\text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V
\]
这里,\(d_k\) 表示键维度大小,用于缩放点积以稳定梯度传播[^4]。
#### 多头注意力的设计理念
为了进一步提升模型的能力,MHSA 将上述过程重复多次,在不同的子空间上独立运行自注意力机制。每个头部都有一组独特的参数矩阵 (\(W_Q^{(i)}\), \(W_K^{(i)}\), \(W_V^{(i)}\)) 来定义特定的映射关系。最终,各个头部的结果被拼接起来并通过另一个线性层投影回原始维度:
\[
\text{MultiHead}(Q,K,V) = [\text{head}_1;...;\text{head}_h] W_O
\]
其中,
\[
\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
\]
这种设计不仅增强了模型捕捉多样化模式的能力,还允许其关注不同粒度的信息。
### 实现代码示例
以下是基于 PyTorch 的简单实现版本:
```python
import torch
import torch.nn as nn
import math
class MultiHeadSelfAttention(nn.Module):
def __init__(self, d_model, num_heads):
super(MultiHeadSelfAttention, self).__init__()
assert d_model % num_heads == 0
self.d_model = d_model
self.num_heads = num_heads
self.depth = d_model // num_heads
self.wq = nn.Linear(d_model, d_model)
self.wk = nn.Linear(d_model, d_model)
self.wv = nn.Linear(d_model, d_model)
self.dense = nn.Linear(d_model, d_model)
def split_heads(self, x, batch_size):
"""Split the last dimension into (num_heads, depth)."""
x = x.view(batch_size, -1, self.num_heads, self.depth)
return x.permute(0, 2, 1, 3) # (batch_size, num_heads, seq_len, depth)
def forward(self, q, k, v, mask=None):
batch_size = q.size(0)
q = self.split_heads(self.wq(q), batch_size)
k = self.split_heads(self.wk(k), batch_size)
v = self.split_heads(self.wv(v), batch_size)
matmul_qk = torch.matmul(q, k.transpose(-2, -1))
scaled_attention_logits = matmul_qk / math.sqrt(self.depth)
if mask is not None:
scaled_attention_logits += (mask * -1e9)
attention_weights = torch.softmax(scaled_attention_logits, dim=-1)
output = torch.matmul(attention_weights, v)
output = output.permute(0, 2, 1, 3).contiguous().view(batch_size, -1, self.d_model)
final_output = self.dense(output)
return final_output, attention_weights
```
此模块实现了标准的 MHSA 结构,支持灵活调整 `d_model` 和 `num_heads` 参数以适应不同规模的任务需求。
---
阅读全文
相关推荐



















