自注意力机制,多头自注意力机制
时间: 2025-01-08 22:01:37 浏览: 49
### 自注意力机制概述
自注意力机制是一种用于捕捉序列数据内部依赖关系的方法,尤其适用于自然语言处理领域。通过计算输入序列中不同位置之间的相互影响程度来构建上下文表示[^1]。
#### 关键概念解释:
- **查询(Query)**, **键(Key)** 和 **值(Value)** 是构成自注意力层的核心组件。这三个向量分别代表当前词项、其他可能关联的词项及其对应的语义贡献。
- **缩放点积注意力(Scaled Dot-Product Attention)** 计算方式如下所示:
\[
\text{Attention}(Q,K,V)=\text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V
\]
其中 \( Q \),\( K \),和 \( V \) 分别对应于查询矩阵、键矩阵和值矩阵;而 \( d_k \) 表示键维度大小。此公式实现了加权求和操作,权重由 softmax 函数决定,从而使得重要部分获得更高分数[^2]。
```python
import numpy as np
def scaled_dot_product_attention(Q, K, V):
"""实现缩放点积注意力"""
dk = K.shape[-1]
scores = np.matmul(Q, K.T) / np.sqrt(dk)
attention_weights = softmax(scores)
output = np.matmul(attention_weights, V)
return output
def softmax(x):
e_x = np.exp(x - np.max(x))
return e_x / e_x.sum(axis=-1, keepdims=True)
# 示例参数初始化 (假设 batch_size=1, seq_len=4, dim_model=3)
np.random.seed(0)
queries = np.random.randn(1, 4, 3)
keys = values = queries.copy()
output = scaled_dot_product_attention(queries, keys, values)
print(output)
```
### 多头自注意力机制详解
为了进一步增强模型表达能力并允许其关注来自不同表征子空间的信息流,引入了多头(self-multihead)结构。具体来说就是将原始特征映射成多个低维空间下的 Query/Key/Value 对,并独立执行上述提到的标准自注意力运算后再拼接起来形成最终输出。
\[
\begin{aligned}
&\mathrm { MultiHead } ( Q , K , V ) \\
=& \operatorname { Concat }\left( head _ { 1} , \ldots , head_ h \right) W ^ O\\
&where \\
&head_i=\operatorname { Attention }\left(QW_I^i,\; KW_K^i,\; VW_V^i\right), i \in [1,h]\\
&W_O,W_Q^{(i)},W_K^{(i)},W_V^{(i)} are learnable parameters.
\end{aligned}
\]
这里 `h` 表示头部数量,每个头部都有自己的线性变换参数集 (\(W_Q^{(i)}\) ,\(W_K^{(i)}\) ,\(W_V^{(i)})\) 。最后再经过一次投影得到整体输出。
```python
class MultiHeadAttention:
def __init__(self, num_heads, embed_dim):
super().__init__()
self.num_heads = num_heads
self.embed_dim = embed_dim
assert embed_dim % num_heads == 0, "Embedding dimension must be divisible by number of heads"
self.head_dim = embed_dim // num_heads
# 初始化权重矩阵
self.q_linear = nn.Linear(embed_dim, embed_dim)
self.k_linear = nn.Linear(embed_dim, embed_dim)
self.v_linear = nn.Linear(embed_dim, embed_dim)
self.out_proj = nn.Linear(embed_dim, embed_dim)
def forward(self, query, key, value):
batch_size = query.size(0)
# 将输入转换为适合多头的形式
q = self.q_linear(query).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
k = self.k_linear(key).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
v = self.v_linear(value).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
# 执行缩放点乘注意力
attn_output = []
for qi, ki, vi in zip(q.split(1,dim=1), k.split(1,dim=1), v.split(1,dim=1)):
out = scaled_dot_product_attention(qi.squeeze(), ki.squeeze(), vi.squeeze())
attn_output.append(out.unsqueeze(dim=1))
concat_attn_outputs = torch.cat(attn_output, dim=1)
final_output = concat_attn_outputs.transpose(1, 2).contiguous().view(batch_size, -1, self.embed_dim)
return self.out_proj(final_output)
```
阅读全文
相关推荐


















