目录
层规范化(LayerNorm) 和批规范化(BatchNorm1d)之间的差异
1. 层规范化(LayerNorm):“按学生自己的水平调整”
2. 批规范化(BatchNorm1d):“按全班同一科目的水平调整”
代码逻辑详解:附录:增加Transformer代码可读性-CSDN博客
Transformer 模型
Transformer 模型是一种基于自注意力机制的深度学习模型,在自然语言处理等领域影响深远。详细解析如下:
一、核心优势:并行计算与全局依赖
与传统的循环神经网络(RNN)不同,Transformer 完全抛弃了 “按顺序处理序列” 的模式,通过自注意力机制让序列中每个元素能直接 “看到” 其他所有元素,既解决了 RNN 难以捕捉长距离依赖的问题,又支持并行计算,大幅提升了训练效率。
二、整体结构:编码器 - 解码器框架
模型由编码器(Encoder) 和解码器(Decoder) 两部分组成,两者均由多个相同的 “层” 堆叠而成(原始论文中各用了 6 层)。
1. 编码器(Encoder):处理输入序列
作用是将输入序列(如 “我喜欢机器学习”)转换为包含上下文信息的向量表示,供解码器使用。
每层包含两个核心子层,且每个子层后都有残差连接(输入直接加到输出,防止梯度消失)和层归一化(让数据分布更稳定,加速训练):
- 子层 1:多头自注意力(Multi-Head Self-Attention)
让序列中每个元素(如 “喜欢”)同时关注序列中其他所有元素(如 “我”“机器学习”),捕捉上下文关系。
“多头” 指将输入分成多个小份(如 8 个头),每个头独立计算注意力,再将结果合并 —— 相当于从不同角度捕捉依赖(比如一个头关注语法,一个头关注语义)。 - 子层 2:前馈神经网络(Feed-Forward Network)
对每个元素的向量进行独立的非线性变换(先升维再降维,加 ReLU 激活),增强模型对局部特征的捕捉能力。
2. 解码器(Decoder):生成输出序列
作用是根据编码器的输出和已生成的部分结果(如翻译任务中已生成的 “I like”),逐步生成完整的目标序列(如 “ I like machine learning ”)。每层包含三个子层(同样带残差连接和层归一化):
- 子层 1:掩码多头自注意力(Masked Multi-Head Self-Attention)
功能类似编码器的自注意力,但额外加入 “掩码”—— 让每个位置只能关注 “前面已生成的元素”,不能偷看 “后面还没生成的元素”(比如生成 “like” 时,只能参考 “I”,不能提前用 “machine”),保证生成逻辑的合理性。 - 子层 2:编码器 - 解码器注意力(Encoder-Decoder Attention)
让解码器 “关注编码器的输出”,即结合输入序列的信息来生成目标序列(比如翻译时,“机器学习” 对应 “machine learning”,需要解码器关注编码器对 “机器学习” 的向量表示)。 - 子层 3:前馈神经网络
与编码器的前馈网络功能相同,对当前子层的输出做进一步特征转换。
三、输入处理:嵌入与位置编码
1. 词嵌入(Embedding)
将输入的离散符号(如单词、字)转换为连续的向量(比如 “猫”→ [0.2, 0.5, -0.3, ...]),让计算机能理解语义。这些向量可以是预训练好的(如 GloVe),也可以在模型训练中从头学习。
2. 位置编码(Positional Encoding)
由于 Transformer 没有 RNN 的 “顺序记忆”,必须手动加入位置信息 —— 通过特定规则生成 “位置向量”,与词嵌入相加,让模型知道 “哪个词在前,哪个词在后”(比如 “我爱你” 和 “你爱我” 的区别)。
原始实现中用了不同频率的正弦 / 余弦函数来生成位置向量,确保不同位置的向量有明显差异,且能体现相对位置关系。
四、注意力机制:核心中的核心
注意力机制的本质是:为序列中的每个元素计算一个 “权重分布”,表示它与其他元素的关联程度,再通过加权求和得到新的向量表示。
1. 自注意力(Self-Attention)
以 “我喜欢机器学习” 为例:
- 每个词(如 “喜欢”)会生成三个向量:查询(Query)(“我要找什么?”)、键(Key)(“我是什么?”)、值(Value)(“我的信息是什么?”)。
- 通过 “查询” 与其他词的 “键” 计算相似度(如 “喜欢” 的 Query 和 “我” 的 Key 相似度高),得到权重(“喜欢” 更关注 “我”)。
- 用权重对其他词的 “值” 加权求和,得到 “喜欢” 的新向量(融合了 “我” 和 “机器学习” 的信息)。
2. 多头注意力(Multi-Head)
将输入分成多个 “头”(如 8 个),每个头独立计算自注意力,再将结果拼接起来。这样做的好处是:每个头可以捕捉不同类型的依赖(比如一个头关注语法关系,一个头关注语义关联),提升模型的表达能力。
3. 掩码注意力(Masked Attention)
仅用于解码器的第一个子层,通过在计算相似度时 “掩盖” 未来位置(如将未来词的权重设为负无穷),确保生成时不依赖未生成的信息,避免逻辑混乱。
五、输出层:生成目标序列
解码器最后一层的输出会经过一个线性层(将向量维度映射到目标词汇表大小)和Softmax 函数,得到每个词的概率分布(如生成 “learning” 的概率是 0.8,“study” 是 0.1),再通过贪心搜索或束搜索(Beam Search)选择最可能的词,逐步生成完整序列。
六、模型架构
层规范化(LayerNorm) 和批规范化(BatchNorm1d)之间的差异
一、带通俗注解的代码
# 假设我们要对比两种“给成绩单打分”的方式:按学生个人调整(层规范化)和按全班科目调整(批规范化)
# 层规范化(LayerNorm):针对每个学生的所有科目成绩做调整(关注单个样本内部)
# 参数2表示每个学生有2门科目
ln = nn.LayerNorm(2)
# 批规范化(BatchNorm1d):针对全班同学的同一门科目成绩做调整(关注整个批次的同一特征)
# 参数2表示有2门科目
bn = nn.BatchNorm1d(2)
# 测试数据:2个学生(样本),每个学生2门科目成绩(特征)
# 比如:学生A的语文1分、数学2分;学生B的语文2分、数学3分
X = torch.tensor([[1, 2], # 学生A的成绩:[语文, 数学]
[2, 3]], dtype=torch.float32) # 学生B的成绩:[语文, 数学]
# 打印两种方式调整后的成绩
print('层规范化(按每个学生自己的情况调整):', ln(X),
'\n批规范化(按全班同一科目的情况调整):', bn(X))
二、核心区别:“调整的参照物” 不同
1. 层规范化(LayerNorm):“按学生自己的水平调整”
-
计算逻辑:对每个学生的所有科目,用 “自己的平均分和波动” 来调整。
比如学生 A 的成绩是 [1,2]:- 自己的平均分是 1.5,波动(方差)很小
- 调整后:把 1 分变成 “低于自己平均分 0.5”,2 分变成 “高于自己平均分 0.5”(对应代码输出里的 [-1, 1])
学生 B 的成绩是 [2,3]:
- 自己的平均分是 2.5
- 调整后:2 分变成 “低于自己平均分 0.5”,3 分变成 “高于自己平均分 0.5”(也是 [-1, 1])
2. 批规范化(BatchNorm1d):“按全班同一科目的水平调整”
-
计算逻辑:对全班同学的同一科目,用 “全班的平均分和波动” 来调整。
比如语文科目的成绩是 [1(A), 2(B)]:数学科目的成绩是 [2(A), 3(B)]:
- 全班语文平均分 1.5,波动很小
- 调整后:A 的 1 分变成 “低于全班平均分 0.5”,B 的 2 分变成 “高于全班平均分 0.5”(对应代码输出第一列的 [-1, 1])
- 全班数学平均分 2.5
- 调整后:A 的 2 分变成 “低于全班平均分 0.5”,B 的 3 分变成 “高于全班平均分 0.5”(对应代码输出第二列的 [-1, 1])
3.结果
残差连接
一、残差连接的直观理解
想象你在训练一个深层网络,假设某一层的输入是X
,这一层需要学习一个复杂的映射F(X)
(比如通过卷积、注意力机制等)。
- 没有残差连接时:这一层的输出是
F(X)
,网络需要直接学习从X
到目标的完整映射。 - 有残差连接时:这一层的输出是
X + F(X)
(即输入X
直接 “跳过” 该层,与该层的学习结果F(X)
相加)。
此时,网络只需要学习输入与目标之间的差异(即F(X)
),而不是完整映射。如果F(X)
学到的是 0,那么输出就等于输入X
,信息不会丢失 —— 这就是 “残差” 的含义(学习的是 “残留的差异”)。
二、残差连接的核心作用
-
缓解梯度消失问题
深层网络训练时,梯度需要从后向前传播。没有残差连接时,梯度经过多层乘法后容易变得极小(梯度消失),导致浅层参数难以更新。
有残差连接时,梯度可以通过 “捷径” 直接传递到浅层(因为X + F(X)
对X
的梯度是 1,不会衰减),使得深层网络(如 Transformer 的几十层、ResNet 的几百层)能够被有效训练。 -
保留原始信息,专注学习差异
网络可以选择 “保留输入”(如果F(X)
接近 0)或 “修正输入”(如果F(X)
有意义)。这种灵活性让模型在学习复杂特征的同时,不会丢失底层的基础信息。例如在 Transformer 中,注意力机制的输出会与原始输入相加,确保重要的基础特征不被复杂计算覆盖。
三、代码中的残差连接示例
以之前的AddNorm
类为例:
return self.ln(self.dropout(Y) + X) # Y是子层输出,X是原始输入,两者相加即残差连接
这里的X + dropout(Y)
就是典型的残差连接:X
直接跳过子层,与子层的学习结果Y
(经过 dropout 防止过拟合)相加。
四、类比理解
可以把残差连接想象成公路上的 “应急通道”:
- 主路(
F(X)
)负责复杂的 “特色路线”(学习复杂特征); - 应急通道(
X
)确保即使主路拥堵(F(X)
没学到有效信息),车辆(信息)也能顺利到达目的地(输出层)。
这种设计让深层网络既能探索复杂模式,又能保证基础信息的稳定传递。
完整代码
"""
文件名: 10.7
作者: 墨尘
日期: 2025/7/17
项目名: dl_env
备注: 完整实现Transformer模型(含编码器、解码器、核心子层),并测试其在机器翻译任务上的功能
"""
import math
import torch
from torch import nn
import pandas as pd # 用于处理注意力权重的缺失值
from d2l import torch as d2l # 依赖d2l库的工具函数(如多头注意力、数据加载)
import matplotlib.pyplot as plt # 用于可视化注意力权重热图
# -------------------------- 1. 基于位置的前馈网络(PositionWiseFFN) --------------------------
# 作用:对序列中每个位置的特征向量独立进行非线性变换,增强模型对局部特征的表达能力
# 地位:Transformer中注意力机制后必接的子层,为特征添加非线性交互
class PositionWiseFFN(nn.Module):
"""基于位置的前馈网络(Transformer子层)"""
def __init__(self, ffn_num_input, ffn_num_hiddens, ffn_num_outputs, **kwargs):
"""
初始化前馈网络
参数详解:
ffn_num_input: 输入特征维度(需与注意力机制输出维度一致)
ffn_num_hiddens: 隐藏层维度(通常大于输入维度,形成"升维-降维"结构)
ffn_num_outputs: 输出特征维度(需与输入维度一致,才能参与残差连接)
"""
super(PositionWiseFFN, self).__init__(** kwargs)
self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens) # 第一层线性变换(升维)
self.relu = nn.ReLU() # 非线性激活函数,引入特征间的交互
self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs) # 第二层线性变换(降维)
def forward(self, X):
"""
前向传播:对序列中每个位置的特征独立应用相同的MLP
参数:
X: 输入张量,形状为(batch_size, seq_len, feature_dim)(batch_size:样本数;seq_len:序列长度)
返回:
输出张量,形状与X一致(保持序列长度和样本数,仅特征维度经非线性变换)
"""
# 计算流程:输入 → 升维(增加特征交互能力) → 非线性激活(引入非线性) → 降维(恢复原特征维度)
return self.dense2(self.relu(self.dense1(X)))
# -------------------------- 2. 残差连接 + 层规范化(AddNorm) --------------------------
# 作用:Transformer中每个子层(注意力/前馈网络)的标配输出处理,解决深层网络训练难题
# 核心逻辑:通过残差连接保留原始信息,通过层规范化稳定特征分布,使模型可训练数百层
class AddNorm(nn.Module):
"""残差连接后进行层规范化(Transformer子层输出的标准处理)"""
def __init__(self, normalized_shape, dropout, **kwargs):
"""
初始化参数
参数详解:
normalized_shape: 层规范化的维度(通常为输入特征的最后一维,如[seq_len, feature_dim])
dropout: Dropout概率(随机丢弃部分特征,防止过拟合)
"""
super(AddNorm, self).__init__(** kwargs)
self.dropout = nn.Dropout(dropout) # Dropout层,仅作用于子层输出(保护原始输入)
self.ln = nn.LayerNorm(normalized_shape) # 层规范化层(对每个样本独立归一化,适合序列数据)
def forward(self, X, Y):
"""
前向传播:先残差连接,再层规范化
参数:
X: 子层的原始输入张量(形状与Y必须一致,否则无法相加)
Y: 子层(如注意力机制/前馈网络)的输出张量
返回:
经过处理的张量(形状与X/Y一致,特征分布更稳定)
"""
# 步骤解析:
# 1. 对Y应用Dropout:随机丢弃部分特征,防止模型过度依赖子层输出
# 2. 残差连接(X + dropout(Y)):保留原始输入信息,缓解梯度消失(若Y无效,输出≈X)
# 3. 层规范化:对每个样本计算均值和方差,将特征缩放到标准分布,加速训练
return self.ln(self.dropout(Y) + X)
# -------------------------- 3. Transformer编码器块(EncoderBlock) --------------------------
# 作用:Transformer编码器的基本单元,堆叠N次形成完整编码器,负责捕获输入序列的内部依赖(如句子中词与词的关联)
# 结构:多头自注意力子层 + 前馈网络子层,每个子层后均接AddNorm
class EncoderBlock(nn.Module):
"""Transformer编码器块(构成完整编码器的基本单元)"""
def __init__(self, key_size, query_size, value_size, num_hiddens,
norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
dropout, use_bias=False, **kwargs):
"""
初始化编码器块
参数详解:
key_size/query_size/value_size: 注意力机制中"键/查询/值"的特征维度
num_hiddens: 隐藏层特征维度(与输入序列的特征维度一致)
norm_shape: AddNorm中层规范化的维度(与输入特征维度匹配)
ffn_num_input/ffn_num_hiddens: 前馈网络的输入/隐藏层维度
num_heads: 多头注意力的头数(将特征拆分为多头并行计算,捕获多尺度依赖)
dropout: Dropout概率(用于注意力和前馈网络)
use_bias: 线性层是否使用偏置(控制模型复杂度)
"""
super(EncoderBlock, self).__init__(** kwargs)
# 子层1:多头自注意力机制(键/查询/值均为输入X,即"自注意力",捕获序列内依赖)
self.attention = d2l.MultiHeadAttention(
# 关键修正:用关键字参数传递,避免位置参数顺序错误(兼容不同版本d2l库)
key_size=key_size,
query_size=query_size,
value_size=value_size,
num_hiddens=num_hiddens,
num_heads=num_heads,
dropout=dropout,
use_bias=use_bias
)
# 子层1的输出处理:残差连接 + 层规范化
self.addnorm1 = AddNorm(norm_shape, dropout)
# 子层2:基于位置的前馈网络(对注意力输出做非线性变换,增强特征表达)
self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens)
# 子层2的输出处理:残差连接 + 层规范化
self.addnorm2 = AddNorm(norm_shape, dropout)
def forward(self, X, valid_lens):
"""
前向传播:多头自注意力 → AddNorm → 前馈网络 → AddNorm
参数:
X: 输入序列张量,形状为(batch_size, seq_len, num_hiddens)
valid_lens: 有效长度张量(形状为(batch_size,)),用于屏蔽序列中的无效位置(如填充的PAD符号)
返回:
经过编码器块处理后的张量(形状与X一致,已捕获序列内各位置的依赖关系)
"""
# 步骤1:通过多头自注意力捕获序列内依赖(如"我"与"家"的关联)
# 自注意力中,查询=键=值=X,valid_lens控制仅关注有效位置(忽略PAD)
Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))
# 步骤2:通过前馈网络对注意力输出做非线性变换,增强特征表达能力
return self.addnorm2(Y, self.ffn(Y))
# -------------------------- 4. 完整Transformer编码器(TransformerEncoder) --------------------------
# 作用:将输入序列(如英文句子)编码为包含上下文信息的特征向量
# 结构:词嵌入 + 位置编码 + N个EncoderBlock堆叠
class TransformerEncoder(d2l.Encoder):
"""Transformer编码器"""
def __init__(self, vocab_size, key_size, query_size, value_size,
num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,
num_heads, num_layers, dropout, use_bias=False, **kwargs):
super(TransformerEncoder, self).__init__(** kwargs)
self.num_hiddens = num_hiddens # 隐藏层维度(与词嵌入维度一致)
self.embedding = nn.Embedding(vocab_size, num_hiddens) # 词嵌入层(将词ID转为向量)
self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout) # 位置编码(注入序列顺序信息)
# 堆叠num_layers个编码器块(层数越多,捕获的上下文越复杂)
self.blks = nn.Sequential()
for i in range(num_layers):
self.blks.add_module("block"+str(i),
EncoderBlock(key_size, query_size, value_size, num_hiddens,
norm_shape, ffn_num_input, ffn_num_hiddens,
num_heads, dropout, use_bias))
def forward(self, X, valid_lens, *args):
"""
前向传播:词嵌入 → 位置编码 → 多层编码器块
参数:
X: 输入词ID序列,形状为(batch_size, seq_len)
valid_lens: 有效长度张量(屏蔽无效位置)
返回:
编码后的特征向量(形状为(batch_size, seq_len, num_hiddens))
"""
# 1. 词嵌入:将词ID转为向量,并缩放(平衡与位置编码的量级)
X = self.embedding(X) * math.sqrt(self.num_hiddens)
# 2. 注入位置编码(Transformer无循环结构,需显式添加位置信息)
X = self.pos_encoding(X)
# 3. 经过所有编码器块(逐层捕获更复杂的上下文)
self.attention_weights = [None] * len(self.blks) # 存储各层注意力权重(用于可视化)
for i, blk in enumerate(self.blks):
X = blk(X, valid_lens)
self.attention_weights[i] = blk.attention.attention.attention_weights # 保存注意力权重
return X
# -------------------------- 5. Transformer解码器块(DecoderBlock) --------------------------
# 作用:Transformer解码器的基本单元,堆叠N次形成完整解码器,负责将编码器输出转换为目标序列(如法语句子)
# 结构:掩蔽多头自注意力(防未来信息泄露) + 编码器-解码器注意力(结合源序列信息) + 前馈网络,均接AddNorm
class DecoderBlock(nn.Module):
"""解码器中第i个块"""
def __init__(self, key_size, query_size, value_size, num_hiddens,
norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
dropout, i, **kwargs):
super(DecoderBlock, self).__init__(** kwargs)
self.i = i # 记录当前是第几个解码器块(用于管理历史状态)
# 子层1:掩蔽多头自注意力(防止解码器关注"未来"位置,如翻译时先译的词不能看后译的词)
self.attention1 = d2l.MultiHeadAttention(
key_size=key_size,
query_size=query_size,
value_size=value_size,
num_hiddens=num_hiddens,
num_heads=num_heads,
dropout=dropout
)
self.addnorm1 = AddNorm(norm_shape, dropout)
# 子层2:编码器-解码器注意力(结合编码器输出与解码器当前输出,如用英文信息指导法语翻译)
self.attention2 = d2l.MultiHeadAttention(
key_size=key_size,
query_size=query_size,
value_size=value_size,
num_hiddens=num_hiddens,
num_heads=num_heads,
dropout=dropout
)
self.addnorm2 = AddNorm(norm_shape, dropout)
# 子层3:前馈网络(增强特征表达)
self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens)
self.addnorm3 = AddNorm(norm_shape, dropout)
def forward(self, X, state):
"""
前向传播:掩蔽自注意力 → AddNorm → 编码器-解码器注意力 → AddNorm → 前馈网络 → AddNorm
参数:
X: 解码器输入序列,形状为(batch_size, seq_len, num_hiddens)
state: 包含编码器输出、编码器有效长度、解码器历史状态(用于推理时累积前文)
返回:
解码器输出 + 更新后的state(包含历史状态,用于下一时间步解码)
"""
enc_outputs, enc_valid_lens = state[0], state[1] # 编码器输出和有效长度
# 管理解码器历史状态(推理时需累积已解码的词,训练时直接用完整序列)
if state[2][self.i] is None:
key_values = X # 训练时:键/值=完整输入序列
else:
key_values = torch.cat((state[2][self.i], X), axis=1) # 推理时:拼接历史序列(已解码的词)
state[2][self.i] = key_values # 更新历史状态
# 训练时:生成掩蔽(下三角矩阵),防止关注未来位置(如第3个词不能看第4个词)
if self.training:
batch_size, num_steps, _ = X.shape
dec_valid_lens = torch.arange(1, num_steps + 1, device=X.device).repeat(batch_size, 1)
else:
dec_valid_lens = None # 推理时无需掩蔽(每次仅解码一个词)
# 子层1:掩蔽自注意力(确保解码顺序正确,不泄露未来信息)
X2 = self.attention1(X, key_values, key_values, dec_valid_lens)
Y = self.addnorm1(X, X2)
# 子层2:编码器-解码器注意力(用编码器输出指导解码,如英文"home"对应法语"chez moi")
Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)
Z = self.addnorm2(Y, Y2)
# 子层3:前馈网络增强特征表达
return self.addnorm3(Z, self.ffn(Z)), state
# -------------------------- 6. 完整Transformer解码器(TransformerDecoder) --------------------------
# 作用:将编码器输出转换为目标序列(如将英文编码转换为法语)
# 结构:词嵌入 + 位置编码 + N个DecoderBlock堆叠 + 输出层(映射到目标词汇表)
class TransformerDecoder(d2l.AttentionDecoder):
def __init__(self, vocab_size, key_size, query_size, value_size,
num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,
num_heads, num_layers, dropout, **kwargs):
super(TransformerDecoder, self).__init__(** kwargs)
self.num_hiddens = num_hiddens
self.num_layers = num_layers # 解码器块数量
self.embedding = nn.Embedding(vocab_size, num_hiddens) # 目标语言词嵌入
self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout) # 目标序列位置编码
# 堆叠num_layers个解码器块
self.blks = nn.Sequential()
for i in range(num_layers):
self.blks.add_module("block"+str(i),
DecoderBlock(key_size, query_size, value_size, num_hiddens,
norm_shape, ffn_num_input, ffn_num_hiddens,
num_heads, dropout, i))
self.dense = nn.Linear(num_hiddens, vocab_size) # 输出层(映射到目标词汇表)
def init_state(self, enc_outputs, enc_valid_lens, *args):
"""初始化解码器状态(包含编码器输出、有效长度、历史状态容器)"""
return [enc_outputs, enc_valid_lens, [None] * self.num_layers] # [None]*num_layers用于存储各块历史
def forward(self, X, state):
"""
前向传播:词嵌入 → 位置编码 → 多层解码器块 → 输出层
参数:
X: 目标序列词ID,形状为(batch_size, seq_len)
state: 解码器初始化状态(来自编码器输出)
返回:
词汇表概率分布(形状为(batch_size, seq_len, vocab_size)) + 更新后的state
"""
# 1. 词嵌入 + 位置编码(注入目标序列的位置信息)
X = self.embedding(X) * math.sqrt(self.num_hiddens)
X = self.pos_encoding(X)
# 2. 经过所有解码器块
self._attention_weights = [[None] * len(self.blks) for _ in range(2)] # 存储注意力权重(自注意力+编码器-解码器注意力)
for i, blk in enumerate(self.blks):
X, state = blk(X, state)
# 保存自注意力权重和解码器-编码器注意力权重(用于可视化)
self._attention_weights[0][i] = blk.attention1.attention.attention_weights
self._attention_weights[1][i] = blk.attention2.attention.attention_weights
# 3. 输出层:映射到目标词汇表(未用softmax,训练时结合交叉熵损失)
return self.dense(X), state
@property
def attention_weights(self):
"""返回注意力权重(用于可视化)"""
return self._attention_weights
# -------------------------- 7. 自定义编码器-解码器模型(解决返回值不匹配问题) --------------------------
class CustomEncoderDecoder(nn.Module):
"""自定义编码器-解码器模型,确保返回值结构与d2l.train_seq2seq兼容"""
def __init__(self, encoder, decoder):
super(CustomEncoderDecoder, self).__init__()
self.encoder = encoder # 编码器实例
self.decoder = decoder # 解码器实例
def forward(self, enc_X, dec_X, enc_valid_lens=None):
"""
前向传播:编码器编码 → 解码器解码
参数:
enc_X: 源序列(如英文句子词ID)
dec_X: 目标序列(如法语句子词ID,用于训练时的"强制教学")
enc_valid_lens: 源序列有效长度
返回:
解码器输出(词汇表概率分布) + 解码器状态(确保与训练函数兼容)
"""
enc_outputs = self.encoder(enc_X, enc_valid_lens) # 编码器输出
dec_state = self.decoder.init_state(enc_outputs, enc_valid_lens) # 初始化解码器状态
return self.decoder(dec_X, dec_state) # 解码器输出(与训练函数期望的"(Y_hat, state)"结构匹配)
# -------------------------- 8. 测试函数(验证各组件功能及完整模型训练) --------------------------
def main():
"""测试各组件的输入输出形状及完整模型在机器翻译任务上的功能"""
# 测试1:PositionWiseFFN(基于位置的前馈网络)
print("===== 测试PositionWiseFFN =====")
ffn = PositionWiseFFN(4, 4, 8) # 输入4维,隐藏层4维,输出8维
ffn.eval() # 评估模式(关闭Dropout,输出可复现)
input_ffn = torch.ones((2, 3, 4)) # 2个样本,序列长度3,每个位置4维特征(全1张量便于观察)
output_ffn = ffn(input_ffn)
print(f"前馈网络输入形状: {input_ffn.shape}") # 预期:torch.Size([2, 3, 4])
print(f"前馈网络输出形状: {output_ffn.shape}") # 预期:torch.Size([2, 3, 8])(特征维度变为8)
print("第一个样本的输出特征(每个位置维度转为8):")
print(output_ffn[0], "\n") # 打印第一个样本的3个位置特征
# 测试2:对比LayerNorm和BatchNorm1d(层规范化 vs 批规范化)
print("===== 测试LayerNorm vs BatchNorm1d =====")
ln = nn.LayerNorm(2) # 层规范化(按样本内特征计算均值/方差)
bn = nn.BatchNorm1d(2) # 批规范化(按批次内同一特征计算均值/方差)
X_norm = torch.tensor([[1.0, 2.0], [2.0, 3.0]], dtype=torch.float32) # 2个样本,每个2维特征
print(f"测试输入:\n{X_norm}")
print(f"layer norm输出:\n{ln(X_norm)}") # 按行(样本内)规范化:每个样本的特征被归一化
print(f"batch norm输出:\n{bn(X_norm)}", "\n") # 按列(特征内)规范化:每个特征的样本被归一化
# 测试3:AddNorm(残差连接 + 层规范化)
print("===== 测试AddNorm =====")
add_norm = AddNorm([3, 4], 0.5) # 规范化维度[3,4],dropout=0.5
add_norm.eval() # 评估模式(Dropout无效)
input_X = torch.ones((2, 3, 4)) # 输入X
input_Y = torch.ones((2, 3, 4)) # 子层输出Y(与X同形状)
output_addnorm = add_norm(input_X, input_Y)
print(f"AddNorm输入形状: {input_X.shape}") # 预期:torch.Size([2, 3, 4])
print(f"AddNorm输出形状: {output_addnorm.shape}", "\n") # 预期:torch.Size([2, 3, 4])(形状不变)
# 测试4:EncoderBlock(Transformer编码器块)
print("===== 测试EncoderBlock =====")
X = torch.ones((2, 100, 24)) # 2个样本,序列长度100,每个位置24维特征
valid_lens = torch.tensor([3, 2]) # 样本1前3个位置有效,样本2前2个位置有效
# 创建编码器块(参数与输入维度匹配)
encoder_blk = EncoderBlock(
key_size=24, query_size=24, value_size=24, num_hiddens=24,
norm_shape=[100, 24], ffn_num_input=24, ffn_num_hiddens=48,
num_heads=8, dropout=0.5)
encoder_blk.eval()
output_enc = encoder_blk(X, valid_lens)
print(f"编码器块输入形状: {X.shape}") # 预期:torch.Size([2, 100, 24])
print(f"编码器块输出形状: {output_enc.shape}", "\n") # 预期:torch.Size([2, 100, 24])(形状不变)
# 测试5:完整编码器(TransformerEncoder)
print("===== 测试TransformerEncoder =====")
encoder = TransformerEncoder(
200, 24, 24, 24, 24, [100, 24], 24, 48, 8, 2, 0.5) # 词汇表大小200,2层编码器块
encoder.eval()
# 输入:2个样本,序列长度100的词ID(值为1)
print(encoder(torch.ones((2, 100), dtype=torch.long), valid_lens).shape) # 预期:torch.Size([2, 100, 24])
# 测试6:解码器块(DecoderBlock)
print("===== 测试DecoderBlock =====")
decoder_blk = DecoderBlock(24, 24, 24, 24, [100, 24], 24, 48, 8, 0.5, 0) # 第0个解码器块
decoder_blk.eval()
X = torch.ones((2, 100, 24)) # 解码器输入
state = [encoder_blk(X, valid_lens), valid_lens, [None]] # 解码器初始状态
print(decoder_blk(X, state)[0].shape) # 预期:torch.Size([2, 100, 24])(解码器输出形状)
# 测试7:完整模型训练(机器翻译任务)
print("\n===== 训练完整Transformer模型(机器翻译) =====")
# 超参数设置
num_hiddens, num_layers, dropout = 32, 2, 0.1 # 隐藏层维度32,2层编码器/解码器
batch_size, num_steps = 64, 10 # 批次大小64,序列长度10
lr, num_epochs, device = 0.005, 200, d2l.try_gpu() # 学习率0.005,训练200轮,使用GPU
ffn_num_input, ffn_num_hiddens, num_heads = 32, 64, 4 # 前馈网络参数,多头注意力头数4
key_size, query_size, value_size = 32, 32, 32 # 注意力机制参数
norm_shape = [32] # 层规范化维度
# 加载机器翻译数据集(英语→法语)
train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)
# 创建编码器和解码器
encoder = TransformerEncoder(
len(src_vocab), key_size, query_size, value_size, num_hiddens,
norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
num_layers, dropout)
decoder = TransformerDecoder(
len(tgt_vocab), key_size, query_size, value_size, num_hiddens,
norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
num_layers, dropout)
# 创建完整模型(使用自定义编码器-解码器,确保返回值结构正确)
net = CustomEncoderDecoder(encoder, decoder)
# 训练模型(使用d2l库的序列到序列训练函数)
d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)
# 测试8:模型翻译效果(英→法)及BLEU评分
print("\n===== 翻译结果及BLEU评分 =====")
engs = ['go .', "i lost .", 'he\'s calm .', 'i\'m home .'] # 英文测试句子
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .'] # 对应的法语参考翻译
for eng, fra in zip(engs, fras):
# 翻译并获取注意力权重
translation, dec_attention_weight_seq = d2l.predict_seq2seq(
net, eng, src_vocab, tgt_vocab, num_steps, device, True)
# 打印翻译结果和BLEU评分(评估翻译质量,越高越好)
print(f'{eng} => {translation}, BLEU评分 {d2l.bleu(translation, fra, k=2):.3f}')
# 测试9:可视化编码器注意力权重(展示模型关注的位置)
print("\n===== 可视化编码器注意力权重 =====")
# 处理编码器注意力权重(调整形状用于可视化)
enc_attention_weights = torch.cat(net.encoder.attention_weights, 0).reshape(
(num_layers, num_heads, -1, num_steps))
print(f"编码器注意力权重形状: {enc_attention_weights.shape}") # 形状:(层数, 头数, 序列长度, 序列长度)
# 绘制注意力热图(每行:查询位置;每列:键位置,颜色越深表示关注度越高)
d2l.show_heatmaps(
enc_attention_weights.cpu(), xlabel='Key positions(键位置)',
ylabel='Query positions(查询位置)', titles=['Head %d' % i for i in range(1, 5)],
figsize=(7, 3.5))
plt.show()
# 测试10:可视化解码器注意力权重
print("\n===== 可视化解码器注意力权重 =====")
# 处理解码器注意力权重(填充缺失值)
dec_attention_weights_2d = [head[0].tolist()
for step in dec_attention_weight_seq
for attn in step for blk in attn for head in blk]
dec_attention_weights_filled = torch.tensor(
pd.DataFrame(dec_attention_weights_2d).fillna(0.0).values)
# 调整形状:(序列长度, 2, 层数, 头数, 序列长度)
dec_attention_weights = dec_attention_weights_filled.reshape(
(-1, 2, num_layers, num_heads, num_steps))
# 分离自注意力和编码器-解码器注意力
dec_self_attention_weights, dec_inter_attention_weights = \
dec_attention_weights.permute(1, 2, 3, 0, 4)
print(f"解码器自注意力权重形状: {dec_self_attention_weights.shape}")
print(f"编码器-解码器注意力权重形状: {dec_inter_attention_weights.shape}")
# 绘制解码器自注意力热图
d2l.show_heatmaps(
dec_self_attention_weights[:, :, :, :len(translation.split()) + 1],
xlabel='Key positions', ylabel='Query positions',
titles=['Head %d' % i for i in range(1, 5)], figsize=(7, 3.5))
plt.show()
# 绘制编码器-解码器注意力热图(展示目标词对源词的关注)
d2l.show_heatmaps(
dec_inter_attention_weights, xlabel='Key positions(源序列位置)',
ylabel='Query positions(目标序列位置)', titles=['Head %d' % i for i in range(1, 5)], figsize=(7, 3.5))
plt.show()
if __name__ == '__main__':
main()