单向自注意力机制
时间: 2025-05-15 20:43:42 浏览: 8
### 单向自注意力机制的概念与实现
单向自注意力机制是一种特殊的注意力机制,主要用于处理具有顺序依赖关系的数据(如自然语言)。它通过限制模型仅关注当前时刻之前的上下文信息,从而保持因果性并防止未来信息泄露。这种设计特别适用于生成任务,例如语言建模或机器翻译。
#### 1. 单向自注意力的核心概念
单向自注意力机制的主要特点是引入 **掩码操作** (Masking),以屏蔽掉未来的词项对当前预测的影响。具体来说,在计算注意力分数时,会使用下三角矩阵作为掩码[^3]。如果某个位置的值被设置为负无穷大,则对应的注意力得分会被置零,从而使该位置的信息不会影响到后续的结果。
#### 2. 数学表达形式
假设输入序列为 \(X = \{x_1, x_2, ..., x_n\}\),其中每个 \(x_i\) 是一个 d 维向量。经过线性变换后得到 Query (\(Q\))、Key (\(K\)) 和 Value (\(V\)):
\[ Q = XW_Q,\ K = XW_K,\ V = XW_V \]
接着计算未加权的注意力分数矩阵 S:
\[ S_{ij} = \frac{\text{Query}_i^\top \cdot \text{Key}_j}{\sqrt{d_k}} \]
为了实现单向特性,需在此基础上施加掩码 M:
\[ S' = S + M \]
这里,\(M\) 是形状相同的矩阵,其上三角部分填充为 -∞,其余部分为 0。最终 softmax 函数作用于修正后的分数矩阵 \(S'\):
\[ A = \text{softmax}(S') \]
最后利用加权求和获得输出 O:
\[ O = AV \]
上述过程确保了任意时间步 t 的输出只取决于前 t-1 步的内容[^4]。
#### 3. PyTorch 实现示例
以下是基于 PyTorch 的简单实现代码片段:
```python
import torch
import torch.nn as nn
import math
class MaskedSelfAttention(nn.Module):
def __init__(self, dim_model, num_heads=8):
super(MaskedSelfAttention, self).__init__()
assert dim_model % num_heads == 0
self.num_heads = num_heads
self.dim_head = dim_model // num_heads
self.query_linear = nn.Linear(dim_model, dim_model)
self.key_linear = nn.Linear(dim_model, dim_model)
self.value_linear = nn.Linear(dim_model, dim_model)
def forward(self, inputs):
batch_size, seq_len, _ = inputs.size()
queries = self.query_linear(inputs).view(batch_size, seq_len, self.num_heads, self.dim_head).transpose(1, 2)
keys = self.key_linear(inputs).view(batch_size, seq_len, self.num_heads, self.dim_head).transpose(1, 2)
values = self.value_linear(inputs).view(batch_size, seq_len, self.num_heads, self.dim_head).transpose(1, 2)
scores = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(self.dim_head)
mask_shape = (seq_len, seq_len)
subsequent_mask = torch.triu(torch.ones(mask_shape), diagonal=1).bool() # 上三角矩阵
scores.masked_fill_(subsequent_mask.unsqueeze(0).unsqueeze(0), float('-inf'))
attn_weights = torch.softmax(scores, dim=-1)
output = torch.matmul(attn_weights, values).transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
return output
```
此模块定义了一个带有掩码功能的标准多头自注意力层。`torch.triu()` 方法用于创建遮挡未来 token 的掩码矩阵。
---
###
阅读全文
相关推荐

















