多头注意力机制的维度
时间: 2025-05-02 08:50:58 浏览: 37
### 多头注意力机制中的维度计算与实现
多头注意力机制是一种通过分解输入向量到多个子空间来提升模型表达能力的技术。其核心思想是对每个子空间分别应用自注意力机制,从而获得更丰富的特征表示。
#### 1. 输入张量的形状
在多头注意力机制中,输入通常是一个三维张量 \( X \in \mathbb{R}^{N \times L \times D_{\text{model}}} \),其中:
- \( N \) 是批量大小,
- \( L \) 是序列长度,
- \( D_{\text{model}} \) 是嵌入维度[^2]。
#### 2. 投影矩阵的作用
为了将输入投影到不同的子空间,定义三个线性变换矩阵 \( W_Q, W_K, W_V \),它们用于生成查询 (Query)、键 (Key) 和值 (Value) 向量。这些矩阵的尺寸分别为:
- 查询矩阵 \( W_Q \in \mathbb{R}^{D_{\text{model}} \times d_k} \),
- 键矩阵 \( W_K \in \mathbb{R}^{D_{\text{model}} \times d_k} \),
- 值矩阵 \( W_V \in \mathbb{R}^{D_{\text{model}} \times d_v} \),
其中 \( d_k = d_v = \frac{D_{\text{model}}}{h} \),\( h \) 表示注意力头的数量。
#### 3. 子空间划分
对于每个多头注意力层,输入会被划分为 \( h \) 个独立的子空间,在每个子空间中执行标准的缩放点积注意力操作:
\[ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V \]
最终,所有头的结果会按最后一个维度拼接起来,并通过一个额外的线性变换恢复原始维度:
```python
import torch
import torch.nn as nn
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads):
super(MultiHeadAttention, self).__init__()
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
self.d_model = d_model
self.num_heads = num_heads
self.depth = d_model // num_heads
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.fc_out = nn.Linear(d_model, d_model)
def split_heads(self, x, batch_size):
"""Split the last dimension into (num_heads, depth)."""
x = x.view(batch_size, -1, self.num_heads, self.depth)
return x.permute(0, 2, 1, 3) # (batch_size, num_heads, seq_len, depth)
def forward(self, q, k, v, mask=None):
batch_size = q.size(0)
q = self.split_heads(self.W_q(q), batch_size) # (batch_size, num_heads, seq_len_q, depth)
k = self.split_heads(self.W_k(k), batch_size) # (batch_size, num_heads, seq_len_k, depth)
v = self.split_heads(self.W_v(v), batch_size) # (batch_size, num_heads, seq_len_v, depth)
scaled_attention = self.scaled_dot_product_attention(q, k, v, mask)
concat_attention = scaled_attention.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
output = self.fc_out(concat_attention)
return output
@staticmethod
def scaled_dot_product_attention(q, k, v, mask=None):
matmul_qk = torch.matmul(q, k.transpose(-2, -1)) # (..., seq_len_q, seq_len_k)
dk = k.shape[-1]
scaled_attention_logits = matmul_qk / torch.sqrt(torch.tensor(dk))
if mask is not None:
scaled_attention_logits += (mask * -1e9)
attention_weights = torch.softmax(scaled_attention_logits, dim=-1) # (..., seq_len_q, seq_len_k)
output = torch.matmul(attention_weights, v) # (..., seq_len_q, depth_v)
return output
```
上述代码展示了如何利用 PyTorch 实现一个多头注意力模块[^4]。
#### 4. TensorFlow 的实现方式
在 TensorFlow 中,可以通过 `tf.keras.layers.MultiHeadAttention` 提供的功能快速构建多头注意力层。以下是基于 TensorFlow 的简单实现示例:
```python
import tensorflow as tf
class MultiHeadAttentionLayer(tf.keras.Model):
def __init__(self, d_model, num_heads):
super(MultiHeadAttentionLayer, self).__init__()
self.attention_layer = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model//num_heads)
def call(self, query, value, key, mask=None):
context_vector, _ = self.attention_layer(query=query, value=value, key=key, attention_mask=mask)
return context_vector
```
此代码片段说明了如何使用 TensorFlow 高级 API 构建类似的结构[^1]。
---
阅读全文
相关推荐


















