### A Python Implementation of Multi-Head Attention
Multi-head attention increases a model's expressive power by running several self-attention heads in parallel. Below is a PyTorch-based implementation.
#### 1. A custom multi-head attention class
The following is a complete `MultiHeadAttention` class built from basic PyTorch modules.
```python
import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.depth = d_model // num_heads
        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)
        self.dense = nn.Linear(d_model, d_model)

    def split_heads(self, x, batch_size):
        """Split the last dimension into (num_heads, depth)."""
        x = x.view(batch_size, -1, self.num_heads, self.depth)
        return x.permute(0, 2, 1, 3)  # (batch_size, num_heads, seq_len, depth)

    def scaled_dot_product_attention(self, q, k, v, mask=None):
        matmul_qk = torch.matmul(q, k.transpose(-2, -1))  # (..., seq_len_q, seq_len_k)
        dk = torch.tensor(k.size()[-1], dtype=torch.float32)
        scaled_attention_logits = matmul_qk / torch.sqrt(dk)
        if mask is not None:
            scaled_attention_logits += (mask * -1e9)
        attention_weights = torch.softmax(scaled_attention_logits, dim=-1)  # (..., seq_len_q, seq_len_k)
        output = torch.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)
        return output, attention_weights

    def forward(self, q, k, v, mask=None):
        batch_size = q.size(0)
        q = self.wq(q)  # (batch_size, seq_len, d_model)
        k = self.wk(k)  # (batch_size, seq_len, d_model)
        v = self.wv(v)  # (batch_size, seq_len, d_model)
        q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
        k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len_k, depth)
        v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len_v, depth)
        scaled_attention, attention_weights = self.scaled_dot_product_attention(q, k, v, mask=mask)
        scaled_attention = scaled_attention.permute(0, 2, 1, 3).contiguous()  # (batch_size, seq_len_q, num_heads, depth)
        concat_attention = scaled_attention.view(batch_size, -1, self.d_model)  # (batch_size, seq_len_q, d_model)
        output = self.dense(concat_attention)  # (batch_size, seq_len_q, d_model)
        return output, attention_weights
```
The code above implements the core logic of multi-head attention [^1] and consists of the following parts:
- **Linear projections**: the inputs are mapped by three separate projection matrices \(W_Q\), \(W_K\), and \(W_V\) into query (Q), key (K), and value (V) vectors.
- **Head splitting**: Q, K, and V are split into multiple heads so that each head attends independently.
- **Scaled dot-product attention**: computes the weighted context representation (see the formulas below) and returns the attention weights for visualization or other uses.
- **Concatenation and output projection**: the per-head results are concatenated and passed through a final linear layer to produce the output.
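For reference, each head computes the standard scaled dot-product attention, where \(d_k\) is the per-head dimension (`depth` in the code), and the heads are then concatenated and projected:

\[
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V
\]

\[
\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\,W^{O}, \qquad \mathrm{head}_i = \mathrm{Attention}(Q W_i^{Q},\, K W_i^{K},\, V W_i^{V})
\]

Here \(W^{O}\) corresponds to `self.dense` in the implementation above.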
#### 2. Usage example
The following shows a simple way to call the module:
```python
# Hyperparameters
d_model = 512
num_heads = 8
seq_length = 10
batch_size = 32

# Random input tensors
query = torch.randn((batch_size, seq_length, d_model))
key = torch.randn((batch_size, seq_length, d_model))
value = torch.randn((batch_size, seq_length, d_model))

# Instantiate the module
mha = MultiHeadAttention(d_model=d_model, num_heads=num_heads)

# Forward pass
output, attn_weights = mha(query, key, value)
print(output.shape)        # Expected: [batch_size, seq_length, d_model]
print(attn_weights.shape)  # Expected: [batch_size, num_heads, seq_length, seq_length]
```
This snippet shows how to initialize the input data and run a forward pass.
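The module also accepts an optional `mask` argument. Since the implementation adds `mask * -1e9` to the attention logits, positions where the mask equals 1 are suppressed. A minimal sketch of a causal (look-ahead) mask used this way, assuming the same tensors as above:

```python
# Causal mask: 1 above the diagonal marks future positions to block.
# Shape (seq_length, seq_length) broadcasts over the batch and head dimensions.
causal_mask = torch.triu(torch.ones(seq_length, seq_length), diagonal=1)

masked_output, masked_weights = mha(query, key, value, mask=causal_mask)
print(masked_weights[0, 0, 0])  # the first query position attends only to itself
```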
### Application in Recommender Systems
In recommender systems, and in sequential recommendation in particular, multi-interest frameworks can combine multi-head attention to capture a user's diverse preferences [^2], which can significantly improve the quality of personalized recommendations.
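As a rough illustration only (not the method of the cited work), the `MultiHeadAttention` module above could be run as self-attention over a user's item-embedding sequence, with each head loosely acting as one "interest"; the names `item_embeddings`, `item_dim`, `history_len`, and `num_interests` below are hypothetical:

```python
# Hypothetical sketch: self-attention over a user's interaction history,
# where each head may specialize in a different facet of user interest.
num_interests = 4    # hypothetical number of interest heads
item_dim = 128       # hypothetical item-embedding size
history_len = 50     # hypothetical length of the interaction history

item_embeddings = torch.randn(batch_size, history_len, item_dim)

interest_attn = MultiHeadAttention(d_model=item_dim, num_heads=num_interests)
user_repr, head_weights = interest_attn(item_embeddings, item_embeddings, item_embeddings)

# user_repr:    (batch_size, history_len, item_dim)  - contextualized history
# head_weights: (batch_size, num_interests, history_len, history_len) - one map per head
```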