class SelfAttention(nn.Module): def __init__(self, embed_dim): super(SelfAttention, self).__init__() self.embed_dim = embed_dim self.query = nn.Linear(embed_dim, embed_dim) self.key = nn.Linear(embed_dim, embed_dim) self.value = nn.Linear(embed_dim, embed_dim) def forward(self, x): q = self.query(x) k = self.key(x) v = self.value(x) attention_weights = F.softmax(torch.matmul(q, k.transpose(1, 2)) / torch.sqrt(torch.tensor(self.embed_dim)), dim=-1) output = torch.matmul(attention_weights, v) return output
时间: 2025-03-09 22:03:05 浏览: 62
<think>嗯,用户想了解如何在PyTorch中自定义一个SelfAttention类。首先,我得回忆一下Self-Attention的基本原理。根据引用[2],Self-Attention机制是用来捕捉输入序列内部的依赖关系的,它通过计算每个位置与其他位置的相关性来生成新的表示。那具体怎么实现呢?
首先,Self-Attention的核心是Q、K、V三个矩阵。输入经过线性变换得到这三个矩阵,然后计算Q和K的点积,缩放后应用softmax得到注意力权重,最后乘以V得到输出。这部分的结构应该用PyTorch的线性层来实现。
然后,用户可能需要知道如何处理多头的情况。不过用户的问题是关于自定义SelfAttention类,所以可能先实现单头,再扩展到多头。不过引用[2]和[3]都提到了MultiHead和Transformer,所以可能用户后续会关心,但现在先聚焦在基础实现上。
接下来要考虑的是输入输出的维度。假设输入是(batch_size, seq_len, embed_dim),线性层应该将embed_dim投影到对应的q、k、v维度。可能需要设置一个隐藏层的大小,比如d_model,然后每个Q、K、V的维度通常是d_model的拆分,比如分成num_heads份。但单头的话,可能直接使用d_model作为每个的维度。
然后计算注意力分数的时候,需要进行缩放,即除以sqrt(d_k),其中d_k是K的维度。这一步很重要,避免点积过大导致softmax梯度消失。所以代码里需要计算这个缩放因子。
另外,注意力权重的softmax应该应用在哪个维度?应该是在最后一个维度,即对每个query对应的所有key的位置进行softmax,所以dim=-1。
还有,是否需要支持mask?比如在解码器中处理可变长度序列时,需要mask掉未来的位置。不过用户的问题没有提到,可能暂时不考虑,但可以提一下扩展的可能性。
最后,实现的时候需要继承nn.Module,定义初始化中的线性层,以及前向传播的逻辑。测试一下输入输出的形状是否正确,梯度是否能正常传播。
可能用户还会关心如何集成到更大的模型中,比如引用[1]中的Unet类使用了注意力机制,可能是在某些层中添加SelfAttention模块。这时候需要注意输入输出的衔接,比如特征图的尺寸是否匹配。
总结一下,步骤应该是:定义线性层生成Q、K、V,计算点积,缩放,softmax,然后乘以V,得到输出。同时,代码中要考虑维度变换和矩阵乘法的顺序,确保计算正确。可能还需要添加残差连接和LayerNorm,但根据问题描述,用户可能只需要基础的SelfAttention实现,高级功能可以作为扩展讨论。</think>以下是用PyTorch实现自定义SelfAttention类的技术详解:
### 一、Self-Attention核心结构
Self-Attention机制通过计算序列内部元素的关联度实现特征增强[^2]。核心组件包含三个线性变换:
$$ Q = XW_Q,\quad K = XW_K,\quad V = XW_V $$
其中$X \in \mathbb{R}^{n \times d_{model}}$为输入张量,$W_Q/W_K/W_V$为可训练参数矩阵
### 二、完整类实现代码
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
class SelfAttention(nn.Module):
def __init__(self, d_model, dropout=0.1):
super().__init__()
self.d_k = d_model # 维度缩放因子
self.q_linear = nn.Linear(d_model, d_model)
self.k_linear = nn.Linear(d_model, d_model)
self.v_linear = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
# 输入x形状: (batch_size, seq_len, d_model)
q = self.q_linear(x) # (B, L, D)
k = self.k_linear(x) # (B, L, D)
v = self.v_linear(x) # (B, L, D)
# 计算注意力分数
scores = torch.matmul(q, k.transpose(-2, -1)) / (self.d_k ** 0.5) # (B, L, L)
# 可选掩码机制
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
# 归一化与输出
attn_weights = F.softmax(scores, dim=-1)
attn_weights = self.dropout(attn_weights)
output = torch.matmul(attn_weights, v) # (B, L, D)
return output, attn_weights
```
### 三、关键实现细节
1. **维度缩放因子**
点积计算时除以$\sqrt{d_k}$防止梯度消失,经验公式验证有效
2. **掩码机制**
通过`mask`参数控制可见范围,适用于语言模型中的因果掩码
3. **参数共享**
输入输出维度保持一致($d_{model}$),便于残差连接
4. **注意力可视化**
返回`attn_weights`用于分析模型关注点
### 四、进阶优化方向
1. 多头机制:将$d_{model}$拆分为多个子空间并行计算
2. 相对位置编码:在计算注意力分数时加入位置偏置项
3. 稀疏注意力:限制每个token的关注范围提升效率
阅读全文
相关推荐


















