TokenFormer模块
论文《TokenFormer: Rethinking Transformer Scaling with Tokenized Model Parameters》
论文地址: https://arxiv.org/abs/2410.23168
发表期刊: ICLR
深度交流Q裙:1051849847
全网同名 【大嘴带你水论文】 B站定时发布详细讲解视频
详细代码见文章最后
1、整体框架
图1. TokenFormer整体框架
TokenFormer是一个原生可扩展的Transformer架构,通过将模型参数视为token来彻底重新思考Transformer的扩展方式。该模块不仅在输入token之间使用注意力机制进行计算,还在token与模型参数之间进行交互,从而增强架构灵活性。TokenFormer用token-参数注意力层(Pattention)替换所有线性投影,其中输入token作为查询,模型参数作为键和值。这种重新表述允许渐进式高效扩展,无需从头重新训练。该模块通过增量添加新的键值参数对,可以从124M扩展到1.4B参数,实现与从头训练的Transformer相当的性能,同时大大降低训练成本。实验表明,TokenFormer在语言建模和视觉任务上都达到了竞争性能,为大规模模型训练提供了新的范式。
2、机制
1、Token-参数注意力(Pattention)层:
TokenFormer的核心创新是Pattention层,它引入了一组可训练的token作为模型参数,然后使用交叉注意力管理输入token与这些参数token之间的交互。给定输入token X∈ℝ^(T×d₁)和输出token O∈ℝ^(T×d₂),Pattention机制引入两组n个可学习参数token:K_P∈ℝ^(n×d₁)表示键,V_P∈ℝ^(n×d₂)表示值。输出通过缩放点积Pattention层计算:Pattention(X,K_P,V_P) = Θ(X·K_P^T)·V_P,其中Θ是修改的softmax操作。
2、修改的Softmax操作:
标准softmax涉及对输入矩阵应用逐元素指数,然后沿token维度进行L1归一化。TokenFormer的修改softmax用沿token维度的L2归一化替代,然后对每个元素应用GeLU激活。这种设计改善了Pattention层的梯度稳定性,平滑了注意力分数,相比标准softmax操作获得更好的性能。具体实现为:先进行L2归一化norm_outputs = inputs / ||inputs||₂ * √dim,然后应用GeLU激活。
3、全注意力架构设计:
TokenFormer遵循pre-norm transformer设计,计算过程为:X_inter = X_in + MHA(LN(X_in)),X_out = X_inter + FFN(LN(X_inter))。在多头自注意力块中,所有线性投影都被Pattention层替换:Q = Pattention(X,K_PQ,V_PQ),K = Pattention(X,K_PK,V_PK),V = Pattention(X,K_PV,V_PV)。然后进行标准的token-token注意力:X_att = softmax(QK^T/√d)V,最后输出投影:O_att = Pattention(X_att,K_PO,V_PO)。
图2. Pattention层的详细结构和参数token机制
3、代码
完整代码见gitcode地址:https://gitcode.com/2301_80107842/research
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Tuple
class Pattention(nn.Module):
    """Token-parameter attention (Pattention) layer.

    Replaces a linear projection with cross-attention in which the input
    tokens act as queries and two sets of learnable parameter tokens act as
    keys and values (TokenFormer, arXiv:2410.23168).

    Args:
        input_channels: Feature dimension of the input tokens (= key dim).
        output_channels: Feature dimension of the output tokens (= value dim).
        param_token_num: Number of learnable parameter tokens n.
        norm_activation_type: Variant of the modified softmax normalization:
            'softmax', 'gelu_l2_norm' (paper default) or 'l2_norm_gelu'.
    """

    def __init__(self, input_channels, output_channels, param_token_num,
                 norm_activation_type='gelu_l2_norm'):
        super().__init__()
        self.param_token_num = param_token_num
        self.param_key_dim = input_channels
        self.param_value_dim = output_channels
        self.norm_activation_type = norm_activation_type
        # Learnable parameter tokens; allocated with torch.empty because
        # _init_parameters() overwrites the values with Xavier-normal init.
        self.key_param_tokens = nn.Parameter(
            torch.empty(self.param_token_num, self.param_key_dim)
        )
        self.value_param_tokens = nn.Parameter(
            torch.empty(self.param_token_num, self.param_value_dim)
        )
        self._init_parameters()

    def _init_parameters(self):
        """Xavier-normal initialization of both parameter-token matrices."""
        nn.init.xavier_normal_(self.key_param_tokens)
        nn.init.xavier_normal_(self.value_param_tokens)

    def nonlinear_norm_func(self, inputs, normalize_type, dim=-1, eps=1e-12):
        """Modified softmax used to normalize the Pattention scores.

        Args:
            inputs: Raw attention scores.
            normalize_type: One of 'softmax', 'gelu_l2_norm', 'l2_norm_gelu'.
            dim: Dimension to normalize over (the parameter-token axis).
            eps: Lower bound on the norm denominator; prevents a 0/0 -> NaN
                when an entire row of scores is zero (e.g. a fully masked
                query, since GeLU(0) == 0).

        Returns:
            Normalized attention weights with the same shape as ``inputs``.

        Raises:
            NotImplementedError: If ``normalize_type`` is unknown.
        """
        if normalize_type == 'softmax':
            # Standard softmax: element-wise exp, then L1 normalization,
            # rescaled by the row length.
            exp_scores = torch.exp(inputs)
            denom = torch.norm(exp_scores, p=1, dim=dim, keepdim=True).clamp_min(eps)
            outputs = exp_scores / denom * inputs.shape[dim]
        elif normalize_type == 'gelu_l2_norm':
            # GeLU activation followed by L2 normalization (paper default);
            # improves gradient stability vs. standard softmax.
            activated = F.gelu(inputs)
            denom = torch.norm(activated, p=2, dim=dim, keepdim=True).clamp_min(eps)
            outputs = activated / denom * math.sqrt(activated.shape[dim])
        elif normalize_type == 'l2_norm_gelu':
            # L2 normalization first, then GeLU activation.
            denom = torch.norm(inputs, p=2, dim=dim, keepdim=True).clamp_min(eps)
            outputs = F.gelu(inputs / denom * math.sqrt(inputs.shape[dim]))
        else:
            raise NotImplementedError(f"Unknown normalize_type: {normalize_type}")
        return outputs

    def forward(self, inputs, dropout_p=0.0, attn_mask=None, scale=None):
        """Cross-attend the input tokens against the parameter tokens.

        Args:
            inputs: Input tokens of shape (B, L, input_channels).
            dropout_p: Dropout probability applied to the attention weights.
            attn_mask: Optional boolean mask broadcastable to (L, n);
                False entries are zeroed before normalization.
            scale: Optional score scale factor; defaults to 1 (unscaled),
                matching the paper's Pattention formulation.

        Returns:
            Output tokens of shape (B, L, output_channels).

        Raises:
            NotImplementedError: If a non-boolean ``attn_mask`` is given.
        """
        query = inputs                      # (B, L, input_channels)
        key = self.key_param_tokens        # (n, input_channels)
        value = self.value_param_tokens    # (n, output_channels)
        L, S = query.size(-2), key.size(-2)
        scale_factor = 1 if scale is None else scale
        # Multiplicative mask: 1 keeps a score, 0 removes it. NOTE(review):
        # zeroing happens *before* the nonlinearity, which only truly masks
        # for activations with f(0) == 0 (GeLU); the 'softmax' variant would
        # turn masked scores into exp(0) == 1.
        attn_bias = torch.ones(L, S, dtype=query.dtype, device=query.device)
        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_bias.masked_fill_(attn_mask.logical_not(), 0)
            else:
                raise NotImplementedError("Only boolean attention mask is supported")
        # Scores between input tokens (queries) and parameter keys.
        attn_weight = query @ key.transpose(-2, -1) * scale_factor  # (B, L, n)
        attn_weight *= attn_bias
        # Modified softmax replaces the standard one (see nonlinear_norm_func).
        attn_weight = self.nonlinear_norm_func(attn_weight, self.norm_activation_type, dim=-1)
        attn_weight = F.dropout(attn_weight, dropout_p, training=self.training)
        # Weighted sum over the parameter value tokens.
        return attn_weight @ value  # (B, L, output_channels)
class TokenFormerSelfAttention(nn.Module):
    """TokenFormer self-attention block.

    Standard multi-head token-token attention in which every linear
    projection (Q, K, V and the output projection) is replaced by a
    Pattention layer.

    Args:
        hidden_size: Model width; must be divisible by ``num_attention_heads``.
        num_attention_heads: Number of attention heads.
        param_token_num: Parameter tokens per Pattention layer
            (defaults to ``hidden_size``).

    Raises:
        ValueError: If ``hidden_size`` is not divisible by
            ``num_attention_heads`` (the original integer division would
            silently produce an inconsistent head_dim).
    """

    def __init__(self, hidden_size, num_attention_heads, param_token_num=None):
        super().__init__()
        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by "
                f"num_attention_heads ({num_attention_heads})"
            )
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.head_dim = hidden_size // num_attention_heads
        self.param_token_num = param_token_num or hidden_size
        # Pattention layers replace the usual nn.Linear Q/K/V projections...
        self.query_pattention = Pattention(hidden_size, hidden_size, self.param_token_num)
        self.key_pattention = Pattention(hidden_size, hidden_size, self.param_token_num)
        self.value_pattention = Pattention(hidden_size, hidden_size, self.param_token_num)
        # ...and the output projection as well.
        self.output_pattention = Pattention(hidden_size, hidden_size, self.param_token_num)
        self.scale = self.head_dim ** -0.5

    def forward(self, hidden_states, attention_mask=None):
        """Run multi-head self-attention.

        Args:
            hidden_states: Input of shape (B, L, hidden_size).
            attention_mask: Optional additive mask broadcastable to the
                (B, heads, L, L) score tensor.

        Returns:
            Output of shape (B, L, hidden_size).
        """
        batch_size, seq_len, _ = hidden_states.shape
        # Q/K/V via token-parameter attention, each (B, L, hidden_size).
        query = self.query_pattention(hidden_states)
        key = self.key_pattention(hidden_states)
        value = self.value_pattention(hidden_states)
        # Reshape to multi-head layout (B, heads, L, head_dim).
        query = query.view(batch_size, seq_len, self.num_attention_heads, self.head_dim).transpose(1, 2)
        key = key.view(batch_size, seq_len, self.num_attention_heads, self.head_dim).transpose(1, 2)
        value = value.view(batch_size, seq_len, self.num_attention_heads, self.head_dim).transpose(1, 2)
        # Scaled dot-product scores; standard softmax here (token-token
        # attention keeps the classic formulation, per the paper).
        attn_scores = torch.matmul(query, key.transpose(-2, -1)) * self.scale
        if attention_mask is not None:
            attn_scores = attn_scores + attention_mask
        attn_probs = F.softmax(attn_scores, dim=-1)
        attn_output = torch.matmul(attn_probs, value)
        # Merge heads back to (B, L, hidden_size).
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, seq_len, self.hidden_size
        )
        # Output projection is also a Pattention layer.
        return self.output_pattention(attn_output)
class TokenFormerFFN(nn.Module):
    """TokenFormer feed-forward network.

    A single Pattention layer stands in for the conventional two-layer MLP;
    capacity is governed by the number of parameter tokens rather than by an
    intermediate projection width.
    """

    def __init__(self, hidden_size, intermediate_size=None, param_token_num=None):
        super().__init__()
        self.hidden_size = hidden_size
        # Kept for interface parity with a standard transformer FFN config;
        # the Pattention replacement has no intermediate projection to size.
        self.intermediate_size = intermediate_size or 4 * hidden_size
        self.param_token_num = param_token_num or 4 * hidden_size
        # One Pattention layer replaces the traditional two-layer MLP.
        self.ffn_pattention = Pattention(hidden_size, hidden_size, self.param_token_num)

    def forward(self, hidden_states):
        """Map (B, L, hidden_size) -> (B, L, hidden_size) via Pattention."""
        return self.ffn_pattention(hidden_states)
class TokenFormerLayer(nn.Module):
    """One pre-norm TokenFormer block: self-attention and FFN sub-blocks,
    each wrapped in a residual connection."""

    def __init__(self, hidden_size, num_attention_heads, intermediate_size=None,
                 param_token_num=None, layer_norm_eps=1e-5):
        super().__init__()
        self.hidden_size = hidden_size
        # Pre-norm LayerNorms, one per sub-block.
        self.input_layernorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        # Multi-head self-attention built from Pattention projections.
        self.self_attention = TokenFormerSelfAttention(
            hidden_size, num_attention_heads, param_token_num
        )
        # Pattention-based feed-forward network.
        self.mlp = TokenFormerFFN(hidden_size, intermediate_size, param_token_num)

    def forward(self, hidden_states, attention_mask=None):
        """x -> x + Attn(LN(x)), then y -> y + FFN(LN(y))."""
        normed = self.input_layernorm(hidden_states)
        hidden_states = hidden_states + self.self_attention(normed, attention_mask)
        normed = self.post_attention_layernorm(hidden_states)
        return hidden_states + self.mlp(normed)
class TokenFormer(nn.Module):
    """TokenFormer language model: embeddings, stacked TokenFormer layers,
    final LayerNorm and an (untied) LM head.

    Supports progressive scaling via :meth:`scale_model`, which grows the
    number of parameter tokens in every Pattention layer while preserving
    the already-trained parameter tokens.
    """

    def __init__(self, vocab_size, hidden_size=768, num_hidden_layers=12,
                 num_attention_heads=12, intermediate_size=None,
                 param_token_num=None, max_position_embeddings=1024,
                 layer_norm_eps=1e-5, pad_token_id=0):
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.param_token_num = param_token_num or hidden_size
        # Token embeddings + learned absolute position embeddings.
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size, padding_idx=pad_token_id)
        self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
        # Stack of TokenFormer layers.
        self.layers = nn.ModuleList([
            TokenFormerLayer(
                hidden_size, num_attention_heads, intermediate_size,
                param_token_num, layer_norm_eps
            ) for _ in range(num_hidden_layers)
        ])
        self.final_layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        # Untied output head.
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Xavier for Linear, N(0, 0.02) for Embedding, identity for LayerNorm."""
        if isinstance(module, nn.Linear):
            nn.init.xavier_normal_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            nn.init.constant_(module.bias, 0)
            nn.init.constant_(module.weight, 1.0)

    def forward(self, input_ids, attention_mask=None, position_ids=None):
        """Run the model.

        Args:
            input_ids: Token ids of shape (B, L).
            attention_mask: Optional padding mask (B, L); 1/True = attend,
                0/False = ignore. Bool, int and float dtypes are accepted.
            position_ids: Optional position ids (B, L); defaults to 0..L-1.

        Returns:
            logits: Output logits of shape (B, L, vocab_size).
        """
        batch_size, seq_len = input_ids.shape
        if position_ids is None:
            position_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
        word_embeds = self.word_embeddings(input_ids)
        position_embeds = self.position_embeddings(position_ids)
        hidden_states = word_embeds + position_embeds
        if attention_mask is not None:
            # Convert the (B, L) padding mask into an additive (B, 1, 1, L)
            # bias: 0 where attended, -10000 where masked. The explicit cast
            # makes bool/int masks work (float minus bool is unsupported).
            attention_mask = attention_mask.to(hidden_states.dtype)[:, None, None, :]
            attention_mask = (1.0 - attention_mask) * -10000.0
        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask)
        hidden_states = self.final_layer_norm(hidden_states)
        return self.lm_head(hidden_states)

    def scale_model(self, new_param_token_num):
        """Grow every Pattention layer to ``new_param_token_num`` parameter tokens.

        Existing key/value parameter tokens are preserved; only new rows are
        appended, enabling progressive scaling without retraining from scratch.
        NOTE(review): new nn.Parameter objects are created, so any optimizer
        constructed before scaling must be rebuilt to track them.

        Args:
            new_param_token_num: Target number of parameter tokens per layer.
        """
        for layer in self.layers:
            # Expand all four Pattention projections of the attention block.
            self._expand_pattention(layer.self_attention.query_pattention, new_param_token_num)
            self._expand_pattention(layer.self_attention.key_pattention, new_param_token_num)
            self._expand_pattention(layer.self_attention.value_pattention, new_param_token_num)
            self._expand_pattention(layer.self_attention.output_pattention, new_param_token_num)
            # Expand the FFN's Pattention layer.
            self._expand_pattention(layer.mlp.ffn_pattention, new_param_token_num)
        self.param_token_num = new_param_token_num

    def _expand_pattention(self, pattention_layer, new_param_token_num):
        """Append freshly initialized rows to one Pattention layer's K/V tokens.

        No-op when the layer already holds at least ``new_param_token_num``
        parameter tokens (shrinking is not supported).
        """
        old_num = pattention_layer.param_token_num
        if new_param_token_num <= old_num:
            return
        add_num = new_param_token_num - old_num
        # Keys: keep old rows, append Xavier-initialized new rows.
        # torch.empty (not randn) since xavier_normal_ overwrites the values.
        old_key = pattention_layer.key_param_tokens.data
        new_key = torch.empty(add_num, pattention_layer.param_key_dim,
                              device=old_key.device, dtype=old_key.dtype)
        nn.init.xavier_normal_(new_key)
        pattention_layer.key_param_tokens = nn.Parameter(torch.cat([old_key, new_key], dim=0))
        # Values: same treatment.
        old_value = pattention_layer.value_param_tokens.data
        new_value = torch.empty(add_num, pattention_layer.param_value_dim,
                                device=old_value.device, dtype=old_value.dtype)
        nn.init.xavier_normal_(new_value)
        pattention_layer.value_param_tokens = nn.Parameter(torch.cat([old_value, new_value], dim=0))
        pattention_layer.param_token_num = new_param_token_num
# Smoke test: build a small model, run it, scale it, run it again.
if __name__ == '__main__':
    batch_size, seq_len, vocab_size = 2, 128, 1000
    # Small TokenFormer configuration.
    model = TokenFormer(
        vocab_size=vocab_size,
        hidden_size=384,
        num_hidden_layers=6,
        num_attention_heads=6,
        param_token_num=384
    )
    # Random token ids with a full (all-ones) attention mask.
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
    attention_mask = torch.ones(batch_size, seq_len)
    logits = model(input_ids, attention_mask)
    # Progressive scaling: 384 -> 512 parameter tokens per Pattention layer.
    print("扩展前参数数量:", sum(p.numel() for p in model.parameters()) / 1e6, 'M')
    model.scale_model(new_param_token_num=512)
    print("扩展后参数数量:", sum(p.numel() for p in model.parameters()) / 1e6, 'M')
    # Forward pass after scaling.
    logits_scaled = model(input_ids, attention_mask)
    # Report shapes.
    print('输入尺寸:', input_ids.size())
    print('输出logits尺寸:', logits.size())
    print('扩展后logits尺寸:', logits_scaled.size())