``` import torch import torch.nn as nn import torch.nn.functional as F class CrossAttention(nn.Module): def __init__(self, embed_dim, hidden_dim, num_heads): super(CrossAttention, self).__init__() self.embed_dim = embed_dim self.hidden_dim = hidden_dim self.num_heads = num_heads self.query_proj = nn.Linear(embed_dim, hidden_dim * num_heads) self.key_proj = nn.Linear(embed_dim, hidden_dim * num_heads) self.value_proj = nn.Linear(embed_dim, hidden_dim * num_heads) self.out_proj = nn.Linear(hidden_dim * num_heads, embed_dim) def forward(self, query, context): """ query: (batch_size, query_len, embed_dim) context: (batch_size, context_len, embed_dim) """ batch_size, query_len, _ = query.size() context_len = context.size(1) # Project input embeddings query_proj = self.query_proj(query).view(batch_size, query_len, self.num_heads, self.hidden_dim) key_proj = self.key_proj(context).view(batch_size, context_len, self.num_heads, self.hidden_dim) value_proj = self.value_proj(context).view(batch_size, context_len, self.num_heads, self.hidden_dim) # Transpose to get dimensions (batch_size, num_heads, len, hidden_dim) query_proj = query_proj.permute(0, 2, 1, 3) key_proj = key_proj.permute(0, 2, 1, 3) value_proj = value_proj.permute(0, 2, 1, 3) # Compute attention scores scores = torch.matmul(query_proj, key_proj.transpose(-2, -1)) / (self.hidden_dim ** 0.5) attn_weights = F.softmax(scores, dim=-1) # Compute weighted context context = torch.matmul(attn_weights, value_proj) # Concatenate heads and project output context = context.permute(0, 2, 1, 3).contiguous().view(batch_size, query_len, -1) output = self.out_proj(context) return output, attn_weights # Example usage: embed_dim = 512 hidden_dim = 64 num_heads = 8 cross_attention = CrossAttention(embed_dim, hidden_dim, num_heads) # Dummy data batch_size = 2 query_len = 10 context_len = 20 query = torch.randn(batch_size, query_len, embed_dim) context = torch.randn(batch_size, context_len, embed_dim) output, attn_weights = cross_attention(query, context) print(output.size()) # Should be (batch_size, query_len, embed_dim) print(attn_weights.size()) # Should be (batch_size, num_heads, query_len, context_len)```请解释该段代码
时间: 2025-03-07 09:10:57 浏览: 40
这段代码实现了交叉注意力机制(Cross Attention),这是自然语言处理和其他序列建模任务中的一种重要技术,尤其在Transformer模型及其变体中有广泛应用。下面我将逐步解释其工作原理。
### 类定义 `CrossAttention`
`CrossAttention` 继承自 PyTorch 的基础模块 `nn.Module` 并实现了多头交叉注意力机制的核心部分:
#### 初始化 (`__init__`)
- **输入参数**:
- `embed_dim`: 嵌入维度大小。
- `hidden_dim`: 每个注意头的隐藏层维度。
- `num_heads`: 注意力头的数量。
- 它创建了三个线性变换投影矩阵用于生成查询、键和值向量,并设定了最终输出的映射矩阵。
#### 正向传播 (`forward`)
接收两个张量作为输入:
- `query` 表示当前需要关注的信息(比如目标句子);
- `context` 则表示参考信息源(例如源端文本或其他上下文)。它们都经过嵌入表示形式转换而来。
接下来的操作步骤如下:
1. **Projection**: 使用线性的全连接网络对每个位置上的特征分别做变换得到对应的Q, K, V。
```python
query_proj = self.query_proj(query).view(batch_size, query_len, self.num_heads, self.hidden_dim)
```
2. **Reshape & Permute Dimensions**: 将形状调整为 `(batch_size, num_heads, length, hidden_dim)` 方便后续计算相似度得分。
```python
query_proj = query_proj.permute(0, 2, 1, 3)
```
3. **Calculate Attention Scores**: 计算 Q 和 K 矩阵之间的点积并除以根号下 d_k 来缩放分数避免梯度过大导致数值不稳定的问题;然后通过 Softmax 函数归一化概率分布获得权重系数。
```python
scores = torch.matmul(query_proj, key_proj.transpose(-2, -1)) / (self.hidden_dim ** 0.5)
attn_weights = F.softmax(scores, dim=-1)
```
4. **Compute Weighted Context Vector**: 根据上述求得的概率加权平均V得到新的表征结果。
```python
context = torch.matmul(attn_weights, value_proj)
```
5. **Concatenate Heads and Project Output**: 最后拼接所有头部的结果并通过一个额外的线性变换降低到原始embedding size以便于进一步传递给下游的任务。
6. 返回最后的输出以及中间产生的attention weights供可视化或分析用途。
```python
return output, attn_weights
```
### 示例使用说明
构造了一个具体的实例演示如何初始化及应用这个组件,在此过程中还展示了随机生成一些假数据来进行测试验证功能是否正常运作。
---
###
阅读全文
相关推荐



















