LLM Study Notes (Datawhale_Happy-LLM) #5: Building a Transformer
Summary of the Core Components for Building a Transformer
1. Basic Functional Modules
- Self-attention: computes dependencies within a sequence from the Q, K, V matrices (a minimal sketch follows this list), with the formula:

$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
$$

- Multi-head attention: splits the features into multiple subspaces and computes attention in parallel, strengthening the model's expressive power
- Positional encoding: adds position information to the sequence via sine and cosine functions, compensating for the Transformer's lack of built-in order awareness
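As a sanity check of the attention formula, here is a minimal sketch of scaled dot-product attention on random tensors; the shapes (batch 1, sequence length 4, d_k = 8) are arbitrary choices for illustration.

```python
import math
import torch
import torch.nn.functional as F

# Hypothetical shapes for illustration: (batch=1, seq_len=4, d_k=8)
Q = torch.randn(1, 4, 8)
K = torch.randn(1, 4, 8)
V = torch.randn(1, 4, 8)

scores = Q @ K.transpose(-2, -1) / math.sqrt(Q.size(-1))  # (1, 4, 4) similarity scores
weights = F.softmax(scores, dim=-1)                        # each row sums to 1
output = weights @ V                                       # (1, 4, 8) weighted sum of the values
```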
2. Core Network Layers
- Layer normalization (LayerNorm): normalizes over all feature dimensions of each sample, with the formula:

$$
\text{LayerNorm}(x) = \alpha \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta
$$

- Feed-forward network (MLP): built from Linear + ReLU + Linear, adding nonlinear expressive power to the model
3. Encoder and Decoder Architecture
- Encoder layer: composed of LayerNorm + MultiHeadAttention + residual connection followed by LayerNorm + MLP + residual connection
- Decoder layer: adds a masked multi-head attention on top of the encoder layer, so the model cannot see future tokens when predicting
- Stacking: multiple encoder/decoder layers are stacked to form a deep network
4. Input and Output Processing
- Embedding layer: maps discrete tokens to continuous vectors, which are summed with the positional encoding before entering the network
- Output layer: a linear layer maps the features to the vocabulary space to produce a probability distribution (a short sketch follows this list)
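A minimal sketch of these two ends of the model; the sizes (vocab_size=1000, n_embed=256) are hypothetical and chosen only for illustration.

```python
import torch
import torch.nn as nn

vocab_size, n_embed = 1000, 256                       # hypothetical sizes
wte = nn.Embedding(vocab_size, n_embed)               # embedding layer: token id -> vector
lm_head = nn.Linear(n_embed, vocab_size, bias=False)  # output layer: vector -> vocab logits

idx = torch.randint(0, vocab_size, (2, 8))            # a batch of 2 sequences of length 8
x = wte(idx)                                          # (2, 8, 256)
logits = lm_head(x)                                   # (2, 8, 1000)
probs = torch.softmax(logits, dim=-1)                 # probability distribution over the vocabulary
```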
5. Key Techniques
- Residual connections: alleviate vanishing gradients when training deep networks
- Dropout: randomly drops neurons to prevent overfitting
- Masking: hides future positions in the decoder to preserve causality during prediction (a minimal mask example follows this list)
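A minimal sketch of the causal mask used by the masking mechanism; the sequence length of 5 is an arbitrary choice.

```python
import torch

seq_len = 5
# -inf strictly above the diagonal: position i may only attend to positions j <= i
mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
print(mask)
# Adding this mask to the attention scores before softmax drives the weights of future positions to zero
```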
Full Transformer Architecture Flow
Input sequence ──→ Embedding + Positional Encoding ──→ Encoder stack ──→ Encoder output
                                                                            │
                                                                            ↓
                                                              Decoder stack (with mask)
                                                                            │
                                                                            ↓
                                                         Linear layer + Softmax ─→ Output sequence
Multi-Head Self-Attention Module
```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class ModelArgs:
    def __init__(self):
        self.n_embed = 256
        self.n_head = 4
        self.head_dim = self.n_embed // self.n_head
        self.dropout = 0.1
        self.max_seq_len = 512
        self.n_layer = 6
        self.vocab_size = None
        self.block_size = None


class MultiHeadAttention(nn.Module):
    def __init__(self, args: ModelArgs, is_causal=False):
        super().__init__()
        assert args.n_embed % args.n_head == 0
        model_parallel_size = 1
        self.n_local_heads = args.n_head // model_parallel_size
        self.head_dim = args.n_embed // args.n_head
        self.is_causal = is_causal
        # Q, K, V projections and the output projection
        self.wq = nn.Linear(args.n_embed, args.n_head * self.head_dim, bias=False)
        self.wk = nn.Linear(args.n_embed, args.n_head * self.head_dim, bias=False)
        self.wv = nn.Linear(args.n_embed, args.n_head * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_head * self.head_dim, args.n_embed, bias=False)
        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        if is_causal:
            # Upper-triangular mask of -inf so each position only attends to earlier positions
            mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
            mask = torch.triu(mask, diagonal=1)
            self.register_buffer("mask", mask)

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        bsz, seqlen, _ = q.shape
        # Project the inputs, then split into heads: (bsz, seqlen, n_head, head_dim)
        xq, xk, xv = self.wq(q), self.wk(k), self.wv(v)
        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        # Move the head dimension forward: (bsz, n_head, seqlen, head_dim)
        xq = xq.transpose(1, 2)
        xk = xk.transpose(1, 2)
        xv = xv.transpose(1, 2)
        # Scaled dot-product attention
        scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
        if self.is_causal:
            assert hasattr(self, 'mask')
            scores = scores + self.mask[:, :, :seqlen, :seqlen]
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        scores = self.attn_dropout(scores)
        output = torch.matmul(scores, xv)
        # Concatenate the heads and apply the output projection
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        output = self.wo(output)
        output = self.resid_dropout(output)
        return output
```
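A quick usage sketch of the module above, relying on the imports and classes just defined; the batch size and sequence length are arbitrary.

```python
args = ModelArgs()
attn = MultiHeadAttention(args, is_causal=True)

x = torch.randn(2, 16, args.n_embed)   # (batch, seq_len, n_embed)
out = attn(x, x, x)                    # self-attention: q = k = v = x
print(out.shape)                       # torch.Size([2, 16, 256])
```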
Feed-Forward Network
```python
class MLP(nn.Module):
    '''Feed-forward network.
    MLP (Multi-Layer Perceptron), used to build the feed-forward sublayer.
    '''
    def __init__(self, dim: int, hidden_dim: int, dropout: float):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Linear -> ReLU -> Linear -> Dropout
        return self.dropout(self.w2(F.relu(self.w1(x))))
```
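A quick check that the MLP preserves the feature dimension; the sizes here are arbitrary, and hidden_dim is conventionally 4x dim, as used in EncoderLayer below.

```python
mlp = MLP(dim=256, hidden_dim=1024, dropout=0.1)
x = torch.randn(2, 16, 256)
print(mlp(x).shape)   # torch.Size([2, 16, 256]) -- same shape in and out
```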
LayerNorm Layer
The math behind the normalization operation

First compute the sample mean (where $Z_j^i$ is the value of sample $i$ in the $j$-th dimension and $m$ is the mini-batch size):

$$
\mu_j = \frac{1}{m}\sum_{i=1}^{m}Z_j^{i}
$$

Then compute the sample variance:

$$
\sigma^2 = \frac{1}{m}\sum_{i=1}^{m}(Z_j^i - \mu_j)^2
$$

Finally, subtract the mean from each sample's value and divide by the standard deviation, which transforms the distribution of the mini-batch into a standard normal distribution:

$$
\widetilde{Z}_j = \frac{Z_j - \mu_j}{\sqrt{\sigma^2 + \epsilon}}
$$

(here the tiny constant $\epsilon$ is added to keep the denominator from being zero)
```python
class LayerNorm(nn.Module):
    '''A simple Layer Norm implementation based on the formulas above.'''
    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        # Learnable scale and shift (alpha and beta in the formula)
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        # Normalize over the last (feature) dimension
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
```
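As a sanity check, this implementation can be compared against PyTorch's built-in nn.LayerNorm; note that it uses the unbiased standard deviation and adds eps to the std rather than the variance, so the two outputs differ only by a tiny amount.

```python
x = torch.randn(2, 16, 256)
ln = LayerNorm(256)
ref = nn.LayerNorm(256, eps=1e-6)    # PyTorch reference implementation
print((ln(x) - ref(x)).abs().max())  # small difference from the unbiased std / eps placement
```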
Encoder
```python
class EncoderLayer(nn.Module):
    '''Encoder layer'''
    def __init__(self, args):
        super().__init__()
        self.attention_norm = LayerNorm(args.n_embed)
        self.attention = MultiHeadAttention(args, is_causal=False)
        self.fnn_norm = LayerNorm(args.n_embed)
        self.feed_forward = MLP(dim=args.n_embed, hidden_dim=args.n_embed * 4, dropout=args.dropout)

    def forward(self, x):
        # Self-attention sublayer with residual connection
        x = self.attention_norm(x)
        h = x + self.attention.forward(x, x, x)
        # Feed-forward sublayer with residual connection
        out = h + self.feed_forward.forward(self.fnn_norm(h))
        return out
```
```python
class Encoder(nn.Module):
    '''Encoder block: a stack of N encoder layers followed by a final LayerNorm'''
    def __init__(self, args):
        super(Encoder, self).__init__()
        self.layers = nn.ModuleList([EncoderLayer(args) for _ in range(args.n_layer)])
        self.norm = LayerNorm(args.n_embed)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return self.norm(x)
```
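A usage sketch of the encoder stack with the default ModelArgs; the batch size and sequence length are arbitrary.

```python
args = ModelArgs()
encoder = Encoder(args)

x = torch.randn(2, 16, args.n_embed)
print(encoder(x).shape)   # torch.Size([2, 16, 256]) -- the shape is preserved through the stack
```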
Decoder
```python
class DecoderLayer(nn.Module):
    '''Decoder layer'''
    def __init__(self, args):
        super().__init__()
        self.attention_norm_1 = LayerNorm(args.n_embed)
        self.mask_attention = MultiHeadAttention(args, is_causal=True)
        self.attention_norm_2 = LayerNorm(args.n_embed)
        self.attention = MultiHeadAttention(args, is_causal=False)
        self.ffn_norm = LayerNorm(args.n_embed)
        self.feed_forward = MLP(dim=args.n_embed, hidden_dim=args.n_embed * 4, dropout=args.dropout)

    def forward(self, x, enc_out):
        # Masked (causal) self-attention sublayer with residual connection
        x = self.attention_norm_1(x)
        x = x + self.mask_attention.forward(x, x, x)
        # Cross-attention sublayer over the encoder output, with residual connection
        x = self.attention_norm_2(x)
        h = x + self.attention.forward(x, enc_out, enc_out)
        # Feed-forward sublayer with residual connection
        out = h + self.feed_forward.forward(self.ffn_norm(h))
        return out
```
```python
class Decoder(nn.Module):
    '''Decoder: a stack of N decoder layers followed by a final LayerNorm'''
    def __init__(self, args):
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList([DecoderLayer(args) for _ in range(args.n_layer)])
        self.norm = LayerNorm(args.n_embed)

    def forward(self, x, enc_out):
        for layer in self.layers:
            x = layer(x, enc_out)
        return self.norm(x)
```
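A usage sketch of the decoder stack; a random tensor stands in for the encoder output, and both sides share the same sequence length (the cross-attention as implemented assumes equal lengths).

```python
args = ModelArgs()
decoder = Decoder(args)

x = torch.randn(2, 16, args.n_embed)        # target-side features
enc_out = torch.randn(2, 16, args.n_embed)  # stand-in for the encoder output
print(decoder(x, enc_out).shape)            # torch.Size([2, 16, 256])
```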
Building a Transformer
```python
class PositionalEncoding(nn.Module):
    '''Sinusoidal positional encoding'''
    def __init__(self, args):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=args.dropout)
        # Precompute the (block_size, n_embed) table of sine/cosine position values
        pe = torch.zeros(args.block_size, args.n_embed)
        position = torch.arange(0, args.block_size).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, args.n_embed, 2) * -(math.log(10000.0) / args.n_embed)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        # Add the positional encoding to the input embeddings (no gradient through pe)
        x = x + self.pe[:, :x.size(1)].requires_grad_(False)
        return self.dropout(x)
```
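A usage sketch; block_size must be set before constructing the module, since it determines the size of the precomputed table.

```python
args = ModelArgs()
args.block_size = 512                 # length of the precomputed positional table

pos_enc = PositionalEncoding(args)
x = torch.zeros(1, 10, args.n_embed)  # embeddings of a length-10 sequence
print(pos_enc(x).shape)               # torch.Size([1, 10, 256])
print(pos_enc.pe[0, :2, :4])          # first few sine/cosine values of the table
```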
A Complete Transformer
```python
class Transformer(nn.Module):
    """The full model"""
    def __init__(self, args):
        super().__init__()
        assert args.vocab_size is not None
        assert args.block_size is not None
        self.args = args
        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(args.vocab_size, args.n_embed),  # token embedding
            wpe=PositionalEncoding(args),                     # positional encoding
            drop=nn.Dropout(args.dropout),
            encoder=Encoder(args),
            decoder=Decoder(args),
        ))
        self.lm_head = nn.Linear(args.n_embed, args.vocab_size, bias=False)
        self.apply(self._init_weights)
        print(f'number of parameters: {self.get_num_params() / 1e6:.2f}M')

    def get_num_params(self, non_embedding=False):
        """Count the number of parameters. The sinusoidal positional encoding is a
        buffer rather than a parameter, so with non_embedding=True only the token
        embedding weights are subtracted."""
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wte.weight.numel()
        return n_params

    def _init_weights(self, module):
        """Initialize weights"""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        device = idx.device
        b, t = idx.size()
        assert t <= self.args.block_size, f'Cannot forward a sequence of length {t}; block size is only {self.args.block_size}'
        # Token embedding + positional encoding + dropout
        tok_emb = self.transformer.wte(idx)
        pos_emb = self.transformer.wpe(tok_emb)
        x = self.transformer.drop(pos_emb)
        # Pass through the encoder, then the decoder (which attends to the encoder output)
        enc_out = self.transformer.encoder(x)
        x = self.transformer.decoder(x, enc_out)
        if targets is not None:
            # Training: compute logits over the full sequence and the cross-entropy loss
            logits = self.lm_head(x)
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1),
                ignore_index=-1)
        else:
            # Inference: only the logits of the last position are needed
            logits = self.lm_head(x[:, [-1], :])
            loss = None
        return logits, loss
```
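A minimal smoke test of the full model, assuming a small hypothetical vocabulary and block size.

```python
args = ModelArgs()
args.vocab_size = 1000   # hypothetical vocabulary size
args.block_size = 128    # maximum sequence length for the positional table

model = Transformer(args)
idx = torch.randint(0, args.vocab_size, (2, 16))      # (batch, seq_len) token ids
targets = torch.randint(0, args.vocab_size, (2, 16))

logits, loss = model(idx, targets)
print(logits.shape, loss.item())                      # torch.Size([2, 16, 1000]) and a scalar loss
```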