多头注意力机制,注意力得分
时间: 2025-01-17 11:04:16 浏览: 52
### 多头注意力机制中的注意力得分计算
在多头注意力机制中,注意力得分为每对令牌之间的关联程度提供了量化指标。具体来说,对于查询向量 \(Q\) 和键向量 \(K\) ,注意力分数通过两者的点积来计算,并除以键向量维度的平方根用于缩放目的[^2]。
此过程可以表示为:
\[ \text{Attention Score} = \frac{Q \cdot K^T}{\sqrt{d_k}} \]
其中:
- \( Q \) 是查询矩阵,
- \( K \) 是键矩阵,
- \( d_k \) 表示键向量的维度。
为了实现更复杂的模式识别,在实际应用中通常会采用多个并行的自注意层——即所谓的“多头”。每个头部独立工作,允许模型关注输入序列的不同部分或特征。最终的结果由各个单独头产生的输出拼接而成,再经过线性变换得到最后的上下文向量。
下面是基于PyTorch框架的一个简单例子展示如何构建一个多头注意力模块:
```python
import torch
import torch.nn as nn
class MultiHeadAttention(nn.Module):
def __init__(self, embed_size, num_heads):
super(MultiHeadAttention, self).__init__()
self.embed_size = embed_size
self.num_heads = num_heads
assert embed_size % num_heads == 0, "Embed size must be divisible by number of heads"
self.depth = embed_size // num_heads
self.wq = nn.Linear(embed_size, embed_size)
self.wk = nn.Linear(embed_size, embed_size)
self.wv = nn.Linear(embed_size, embed_size)
self.fc_out = nn.Linear(embed_size, embed_size)
def forward(self, q, k, v, mask=None):
batch_size = q.size(0)
q = self.wq(q).view(batch_size, -1, self.num_heads, self.depth).transpose(1, 2)
k = self.wk(k).view(batch_size, -1, self.num_heads, self.depth).transpose(1, 2)
v = self.wv(v).view(batch_size, -1, self.num_heads, self.depth).transpose(1, 2)
energy = torch.matmul(q, k.transpose(-2,-1)) / ((self.depth ** 0.5))
if mask is not None:
energy = energy.masked_fill(mask==0, float('-inf'))
attention_scores = torch.softmax(energy, dim=-1)
out = torch.matmul(attention_scores, v).transpose(1, 2).contiguous().view(batch_size, -1, self.embed_size)
return self.fc_out(out), attention_scores
```
在这个代码片段里,`MultiHeadAttention` 类实现了上述提到的过程,包括创建查询、键和值权重矩阵;分割成不同‘头’;以及执行缩放后的点乘操作来获得注意力分布。最后一步是对这些分布取softmax函数作为概率分布形式返回给调用者[^1]。
阅读全文
相关推荐



















