class MultiHeadAttentionBlock(torch.nn.Module): def __init__(self, dim_val, dim_attn, n_heads): super(MultiHeadAttentionBlock, self).__init__() self.heads = [] for i in range(n_heads): self.heads.append(AttentionBlock(dim_val, dim_attn)) self.heads = nn.ModuleList(self.heads) self.fc = nn.Linear(n_heads * dim_val, dim_val, bias = False) def forward(self, x, kv = None): a = [] for h in self.heads: a.append(h(x, kv = kv)) a = torch.stack(a, dim = -1) a = a.flatten(start_dim = 2) x = self.fc(a) return x
时间: 2025-04-07 15:09:37 浏览: 32
这是一个基于 PyTorch 的多头注意力机制 (Multi-Head Attention) 模块的实现。下面是对这个模块的功能以及其实现细节的一个简要解释:
### 类 `MultiHeadAttentionBlock` 解析
#### 构造函数 (`__init__`)
```python
def __init__(self, dim_val, dim_attn, n_heads):
```
**参数说明**:
- **dim_val**: 表示每个值向量(Value Vector)的维度大小。
- **dim_attn**: 表示用于计算注意分数的关键字向量(Key Vector 和 Query Vector)的维度大小。
- **n_heads**: 表示有多少个独立的“注意力头”(attention heads),即模型会并行地运行多少次单头注意力。
**初始化过程**:
1. 首先通过循环创建了 `n_heads` 个单独的注意力层,并将它们存入列表 `heads` 中;
```python
self.heads = []
for i in range(n_heads):
self.heads.append(AttentionBlock(dim_val, dim_attn))
```
2. 确保这些注意力层可以被正确识别为 PyTorch Module 组件的一部分,因此将其转换成 `nn.ModuleList()` 对象存储起来。
```python
self.heads = nn.ModuleList(self.heads)
```
3. 定义了一个全连接层 `fc` ,它的作用是整合所有头部的结果并将结果映射到期望的输出维度上:
```python
self.fc = nn.Linear(n_heads * dim_val, dim_val, bias=False)
```
#### 前向传播函数 (`forward`)
```python
def forward(self, x, kv=None):
```
该方法接受输入张量 `x`, 并允许额外提供键值对(`kv`)作为输入数据,默认情况下使用自身作为键和值。
前向传播的具体步骤如下:
1. 将输入传给每一个单独定义好的注意力组件(head),获得多个表示形式;
```python
a = []
for h in self.heads:
a.append(h(x, kv=kv))
```
2. 把从各个头得到的结果堆叠在一起形成一个新的张量:
```python
a = torch.stack(a, dim=-1)
```
3. 展平特定维度的数据结构以便后续处理:
```python
a = a.flatten(start_dim=2)
```
4. 最终把展平后的特征送入线性变换网络以生成最终输出结果:
```python
x = self.fc(a)
return x
```
总结来说,这一段代码实现了标准Transformer架构中的核心部分之一——多头自注意力机制(Multi-head Self-Attention Mechanism)。
---
###
阅读全文
相关推荐


















