Self-Attention Mechanism
### Overview of Self-Attention
The self-attention mechanism allows a model to attend to information at different positions when processing sequence data, strengthening its grasp of context [^1]. It works by computing the interaction of three quantities: queries (Query), keys (Key), and values (Value).
#### Defining Queries, Keys, and Values
For each input token, three vectors are produced: a query (Q), a key (K), and a value (V). These are usually obtained through learned linear transformations:
```python
import torch.nn as nn

class AttentionHead(nn.Module):
    """Projects each input token into query, key, and value vectors."""
    def __init__(self, d_model, d_k):
        super().__init__()
        self.query = nn.Linear(d_model, d_k)
        self.key = nn.Linear(d_model, d_k)
        self.value = nn.Linear(d_model, d_k)

    def forward(self, x):
        # x: (batch, seq_len, d_model) -> each output: (batch, seq_len, d_k)
        Q = self.query(x)
        K = self.key(x)
        V = self.value(x)
        return Q, K, V
```
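To make the shapes concrete, here is a quick usage sketch; the batch size, sequence length, and the dimensions `d_model=512`, `d_k=64` are illustrative choices, not values fixed by the text:
```python
import torch

x = torch.randn(2, 10, 512)               # batch of 2 sequences, 10 tokens each
head = AttentionHead(d_model=512, d_k=64)
Q, K, V = head(x)
print(Q.shape, K.shape, V.shape)  # each torch.Size([2, 10, 64])
```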
#### Computing Attention Scores
Next, the queries and keys are combined by a dot product, and the result is divided by the square root of the key dimension to keep gradients stable. A softmax then turns the scores into a probability distribution that serves as the attention weights [^2].
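Written out, this is the standard scaled dot-product attention formula:

$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V
$$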
```python
import math
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V, mask=None, dropout=None):
    """Compute scaled dot-product attention and apply the weights to the values."""
    d_k = Q.size(-1)
    # Score matrix of shape (..., seq_len, seq_len), scaled to stabilize gradients
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # Masked positions get a large negative score, so softmax assigns them ~0 weight
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)  # regularize by randomly dropping attention weights
    output = torch.matmul(p_attn, V)
    return output, p_attn
```
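Combining the two pieces, a minimal end-to-end sketch of a single attention head, using the same illustrative dimensions as above:
```python
import torch

x = torch.randn(2, 10, 512)               # (batch, seq_len, d_model)
head = AttentionHead(d_model=512, d_k=64)
Q, K, V = head(x)
out, weights = scaled_dot_product_attention(Q, K, V)
print(out.shape)      # torch.Size([2, 10, 64]): one context vector per token
print(weights.shape)  # torch.Size([2, 10, 10]): each row sums to 1
```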
#### Multi-Head Attention
To capture richer feature representations, multi-head attention is introduced into the architecture. It runs several independent attention computations in parallel and concatenates the results into a new representation. The implementation relies on a small `clones` helper that stacks identical layers:
```python
import copy
import torch.nn as nn

def clones(module, N):
    """Stack N independent copies of a module in a ModuleList."""
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, d_model, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        # Four projections: one each for Q, K, V, plus the final output projection
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        if mask is not None:
            # Same mask applied to all heads
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)
        # 1) Linear projections in batch: d_model => num_heads x d_k
        query, key, value = [
            l(x).view(nbatches, -1, self.num_heads, self.d_k).transpose(1, 2)
            for l, x in zip(self.linears, (query, key, value))
        ]
        # 2) Apply attention to all projected vectors in batch
        x, self.attn = scaled_dot_product_attention(
            query, key, value, mask=mask, dropout=self.dropout)
        # 3) Concatenate heads via a view, then apply the final linear projection
        x = x.transpose(1, 2).contiguous().view(
            nbatches, -1, self.num_heads * self.d_k)
        return self.linears[-1](x)
```
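Finally, a minimal self-attention usage sketch, where the query, key, and value inputs are all the same tensor; the sizes are again illustrative:
```python
import torch

x = torch.randn(2, 10, 512)                        # (batch, seq_len, d_model)
mha = MultiHeadAttention(num_heads=8, d_model=512)
out = mha(x, x, x)                                 # self-attention: Q = K = V = x
print(out.shape)  # torch.Size([2, 10, 512])
```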