Say Goodbye to Hundred-Billion-Parameter Alchemy! I Turned Transformer Parameters into "Lego Bricks" and Cut Model Upgrade Costs by 90%!

The TokenFormer Module

Paper: "TokenFormer: Rethinking Transformer Scaling with Tokenized Model Parameters"
Paper link: https://arxiv.org/abs/2410.23168
Venue: ICLR
The full code appears at the end of this article.

1. Overall Framework

Figure 1. Overall TokenFormer framework

TokenFormer is a natively scalable Transformer architecture that fundamentally rethinks how Transformers scale by treating model parameters as tokens. The module computes not only attention among input tokens but also interactions between tokens and model parameters, which increases architectural flexibility. TokenFormer replaces every linear projection with a token-parameter attention (Pattention) layer, in which input tokens act as queries and model parameters act as keys and values. This reformulation enables progressive, efficient scaling without retraining from scratch: by incrementally appending new key-value parameter pairs, the model can grow from 124M to 1.4B parameters while matching the performance of Transformers trained from scratch, at a greatly reduced training cost. Experiments show that TokenFormer achieves competitive performance on both language modeling and vision tasks, offering a new paradigm for training large-scale models.

2. Mechanism

1. The Token-Parameter Attention (Pattention) Layer

The core innovation of TokenFormer is the Pattention layer, which introduces a set of trainable tokens as model parameters and uses cross-attention to manage the interaction between input tokens and these parameter tokens. Given input tokens I ∈ ℝ^(T×d₁) and output tokens O ∈ ℝ^(T×d₂), the Pattention mechanism introduces two sets of n learnable parameter tokens: K_P ∈ ℝ^(n×d₁) as keys and V_P ∈ ℝ^(n×d₂) as values. The output is computed by a scaled dot-product Pattention layer: Pattention(X, K_P, V_P) = Θ(X·K_P^T)·V_P, where Θ is a modified softmax operation.
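In matrix form, the computation above can be sketched in a few lines. This is a minimal illustration, not the official implementation; the `pattention` function name is my own, and Θ is sketched here as GeLU followed by L2 normalization (one of the modified-softmax variants discussed in the next section):

```python
import torch
import torch.nn.functional as F

def pattention(X, K_P, V_P):
    """Minimal token-parameter attention sketch.
    X: (T, d1) input tokens; K_P: (n, d1) key parameter tokens;
    V_P: (n, d2) value parameter tokens."""
    scores = X @ K_P.T  # (T, n) token-parameter similarities
    # Theta: modified softmax, sketched as GeLU + L2 normalization over n
    act = F.gelu(scores)
    theta = act / act.norm(p=2, dim=-1, keepdim=True) * (scores.shape[-1] ** 0.5)
    return theta @ V_P  # (T, d2) output tokens

out = pattention(torch.randn(4, 8), torch.randn(32, 8), torch.randn(32, 16))
print(out.shape)  # torch.Size([4, 16])
```

Note that the output dimensionality d₂ is set entirely by V_P, and the number of parameter tokens n never appears in the output shape, which is what makes growing n painless.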

2. The Modified Softmax Operation

The standard softmax applies an element-wise exponential to the input matrix and then L1-normalizes along the token dimension. TokenFormer's modified softmax replaces this with L2 normalization along the token dimension followed by an element-wise GeLU activation. This design improves the gradient stability of the Pattention layer, smooths the attention scores, and yields better performance than the standard softmax. Concretely: first L2-normalize, norm_outputs = inputs / ||inputs||₂ · √dim, then apply GeLU.
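To make the difference concrete, here is a minimal sketch (the `modified_softmax` helper name is my own) of the L2-norm-then-GeLU variant side by side with the standard softmax on large logits. The implementation later in this article ships both orderings, `gelu_l2_norm` and `l2_norm_gelu`:

```python
import torch
import torch.nn.functional as F

def modified_softmax(s, dim=-1):
    # L2-normalize along the parameter-token dimension, rescale by sqrt(n),
    # then apply GeLU ('l2_norm_gelu' ordering).
    n = s.shape[dim]
    s = s / s.norm(p=2, dim=dim, keepdim=True) * (n ** 0.5)
    return F.gelu(s)

s = torch.randn(2, 5) * 100.0   # large logits
std = torch.softmax(s, dim=-1)  # standard softmax saturates toward one-hot
mod = modified_softmax(s)       # L2 norm bounds each row to norm sqrt(n)
print(std.max(dim=-1).values)   # ~1.0 per row: nearly all mass on one entry
print(mod.abs().max())          # bounded by sqrt(n), no saturation
```

Because the exponential is gone, scores no longer collapse to a near-one-hot distribution for large inputs, which is the gradient-stability argument made above.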

3. Full Attention Architecture Design

TokenFormer follows the pre-norm transformer design: X_inter = X_in + MHA(LN(X_in)), X_out = X_inter + FFN(LN(X_inter)). In the multi-head self-attention block, every linear projection is replaced with a Pattention layer: Q = Pattention(X, K_P^Q, V_P^Q), K = Pattention(X, K_P^K, V_P^K), V = Pattention(X, K_P^V, V_P^V). Standard token-token attention then follows, X_att = softmax(QK^T/√d)V, and the output projection is again a Pattention layer: O_att = Pattention(X_att, K_P^O, V_P^O).
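The two residual equations above can be sketched as a block skeleton. This is only an illustration of the pre-norm wiring; the MHA and FFN arguments stand in for the Pattention-based submodules defined in the full code below:

```python
import torch
import torch.nn as nn

class PreNormBlock(nn.Module):
    """Skeleton of the pre-norm residual block used by TokenFormer."""
    def __init__(self, d, mha, ffn):
        super().__init__()
        self.ln1, self.ln2 = nn.LayerNorm(d), nn.LayerNorm(d)
        self.mha, self.ffn = mha, ffn

    def forward(self, x):
        x = x + self.mha(self.ln1(x))  # X_inter = X_in + MHA(LN(X_in))
        x = x + self.ffn(self.ln2(x))  # X_out = X_inter + FFN(LN(X_inter))
        return x

# Identity placeholders just to show the data flow and shapes
blk = PreNormBlock(8, nn.Identity(), nn.Identity())
y = blk(torch.randn(2, 3, 8))
print(y.shape)  # torch.Size([2, 3, 8])
```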

Figure 2. Detailed structure of the Pattention layer and the parameter-token mechanism

3. Code

The full code is available on GitCode: https://gitcode.com/2301_80107842/research

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Tuple

class Pattention(nn.Module):
    """Token-parameter attention (Pattention) layer."""
    def __init__(self, input_channels, output_channels, param_token_num,
                 norm_activation_type='gelu_l2_norm'):
        super().__init__()
        self.param_token_num = param_token_num
        self.param_key_dim = input_channels
        self.param_value_dim = output_channels
        self.norm_activation_type = norm_activation_type

        # Learnable parameter tokens
        self.key_param_tokens = nn.Parameter(
            torch.randn(self.param_token_num, self.param_key_dim)
        )
        self.value_param_tokens = nn.Parameter(
            torch.randn(self.param_token_num, self.param_value_dim)
        )

        # Initialize parameters
        self._init_parameters()

    def _init_parameters(self):
        """Initialize the parameter tokens."""
        nn.init.xavier_normal_(self.key_param_tokens)
        nn.init.xavier_normal_(self.value_param_tokens)

    def nonlinear_norm_func(self, inputs, normalize_type, dim=-1):
        """Modified softmax operation."""
        if normalize_type == 'softmax':
            # Standard softmax: exponential + L1 normalization
            nonlinear_outputs = torch.exp(inputs)
            norm_outputs = nonlinear_outputs / torch.norm(
                nonlinear_outputs, p=1, dim=dim, keepdim=True
            ) * inputs.shape[dim]
            outputs = norm_outputs
        elif normalize_type == 'gelu_l2_norm':
            # GeLU followed by L2 normalization
            nonlinear_outputs = F.gelu(inputs)
            norm_outputs = nonlinear_outputs / torch.norm(
                nonlinear_outputs, p=2, dim=dim, keepdim=True
            ) * math.sqrt(nonlinear_outputs.shape[dim])
            outputs = norm_outputs
        elif normalize_type == 'l2_norm_gelu':
            # L2 normalization followed by GeLU
            norm_outputs = inputs / torch.norm(
                inputs, p=2, dim=dim, keepdim=True
            ) * math.sqrt(inputs.shape[dim])
            nonlinear_outputs = F.gelu(norm_outputs)
            outputs = nonlinear_outputs
        else:
            raise NotImplementedError(f"Unknown normalize_type: {normalize_type}")

        return outputs

    def forward(self, inputs, dropout_p=0.0, attn_mask=None, scale=None):
        """
        Forward pass.
        Args:
            inputs: input tokens (B, L, input_channels)
            dropout_p: dropout probability
            attn_mask: attention mask
            scale: scaling factor
        Returns:
            output tokens (B, L, output_channels)
        """
        query = inputs  # (B, L, input_channels)
        key = self.key_param_tokens  # (param_token_num, input_channels)
        value = self.value_param_tokens  # (param_token_num, output_channels)

        L, S = query.size(-2), key.size(-2)
        scale_factor = 1 if scale is None else scale

        # Build the attention bias (1 = keep, 0 = mask out)
        attn_bias = torch.ones(L, S, dtype=query.dtype, device=query.device)

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_bias.masked_fill_(attn_mask.logical_not(), 0)
            else:
                raise NotImplementedError("Only boolean attention mask is supported")

        # Compute attention weights
        attn_weight = query @ key.transpose(-2, -1) * scale_factor  # (B, L, S)
        attn_weight *= attn_bias

        # Apply the modified softmax
        attn_weight = self.nonlinear_norm_func(attn_weight, self.norm_activation_type, dim=-1)
        attn_weight = F.dropout(attn_weight, dropout_p, training=self.training)

        # Compute the output
        output = attn_weight @ value  # (B, L, output_channels)

        return output

class TokenFormerSelfAttention(nn.Module):
    """TokenFormer self-attention layer."""
    def __init__(self, hidden_size, num_attention_heads, param_token_num=None):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.head_dim = hidden_size // num_attention_heads
        self.param_token_num = param_token_num or hidden_size

        # Replace the Q/K/V linear projections with Pattention layers
        self.query_pattention = Pattention(hidden_size, hidden_size, self.param_token_num)
        self.key_pattention = Pattention(hidden_size, hidden_size, self.param_token_num)
        self.value_pattention = Pattention(hidden_size, hidden_size, self.param_token_num)

        # The output projection is also a Pattention layer
        self.output_pattention = Pattention(hidden_size, hidden_size, self.param_token_num)

        self.scale = self.head_dim ** -0.5

    def forward(self, hidden_states, attention_mask=None):
        """
        Forward pass.
        Args:
            hidden_states: input hidden states (B, L, hidden_size)
            attention_mask: attention mask
        Returns:
            output hidden states (B, L, hidden_size)
        """
        batch_size, seq_len, _ = hidden_states.shape

        # Generate Q, K, V via Pattention
        query = self.query_pattention(hidden_states)  # (B, L, hidden_size)
        key = self.key_pattention(hidden_states)      # (B, L, hidden_size)
        value = self.value_pattention(hidden_states)  # (B, L, hidden_size)

        # Reshape to multi-head format
        query = query.view(batch_size, seq_len, self.num_attention_heads, self.head_dim).transpose(1, 2)
        key = key.view(batch_size, seq_len, self.num_attention_heads, self.head_dim).transpose(1, 2)
        value = value.view(batch_size, seq_len, self.num_attention_heads, self.head_dim).transpose(1, 2)

        # Compute attention scores
        attn_scores = torch.matmul(query, key.transpose(-2, -1)) * self.scale

        # Apply the attention mask
        if attention_mask is not None:
            attn_scores = attn_scores + attention_mask

        # Standard softmax over token-token scores
        attn_probs = F.softmax(attn_scores, dim=-1)

        # Compute the attention output
        attn_output = torch.matmul(attn_probs, value)

        # Reshape back to the original format
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, seq_len, self.hidden_size
        )

        # Output projection via Pattention
        output = self.output_pattention(attn_output)

        return output

class TokenFormerFFN(nn.Module):
    """TokenFormer feed-forward network."""
    def __init__(self, hidden_size, intermediate_size=None, param_token_num=None):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size or 4 * hidden_size
        self.param_token_num = param_token_num or 4 * hidden_size

        # A single Pattention layer replaces the traditional two-layer MLP;
        # the number of parameter tokens plays the role of the hidden width.
        self.ffn_pattention = Pattention(
            hidden_size, hidden_size, self.param_token_num
        )

    def forward(self, hidden_states):
        """Forward pass."""
        return self.ffn_pattention(hidden_states)

class TokenFormerLayer(nn.Module):
    """A single TokenFormer layer."""
    def __init__(self, hidden_size, num_attention_heads, intermediate_size=None,
                 param_token_num=None, layer_norm_eps=1e-5):
        super().__init__()
        self.hidden_size = hidden_size

        # Layer normalization
        self.input_layernorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)

        # Self-attention
        self.self_attention = TokenFormerSelfAttention(
            hidden_size, num_attention_heads, param_token_num
        )

        # Feed-forward network
        self.mlp = TokenFormerFFN(hidden_size, intermediate_size, param_token_num)

    def forward(self, hidden_states, attention_mask=None):
        """Forward pass (pre-norm residual blocks)."""
        # Self-attention branch
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        attn_output = self.self_attention(hidden_states, attention_mask)
        hidden_states = residual + attn_output

        # MLP branch
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        mlp_output = self.mlp(hidden_states)
        hidden_states = residual + mlp_output

        return hidden_states

class TokenFormer(nn.Module):
    """The TokenFormer model."""
    def __init__(self, vocab_size, hidden_size=768, num_hidden_layers=12,
                 num_attention_heads=12, intermediate_size=None,
                 param_token_num=None, max_position_embeddings=1024,
                 layer_norm_eps=1e-5, pad_token_id=0):
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.param_token_num = param_token_num or hidden_size

        # Embedding layers
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size, padding_idx=pad_token_id)
        self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)

        # TokenFormer layers
        self.layers = nn.ModuleList([
            TokenFormerLayer(
                hidden_size, num_attention_heads, intermediate_size,
                param_token_num, layer_norm_eps
            ) for _ in range(num_hidden_layers)
        ])

        # Final layer normalization
        self.final_layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)

        # Output head
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize weights."""
        if isinstance(module, nn.Linear):
            nn.init.xavier_normal_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            nn.init.constant_(module.bias, 0)
            nn.init.constant_(module.weight, 1.0)

    def forward(self, input_ids, attention_mask=None, position_ids=None):
        """
        Forward pass.
        Args:
            input_ids: input token IDs (B, L)
            attention_mask: attention mask (B, L)
            position_ids: position IDs (B, L)
        Returns:
            logits: output logits (B, L, vocab_size)
        """
        batch_size, seq_len = input_ids.shape

        # Position IDs
        if position_ids is None:
            position_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)

        # Embeddings
        word_embeds = self.word_embeddings(input_ids)
        position_embeds = self.position_embeddings(position_ids)
        hidden_states = word_embeds + position_embeds

        # Convert the padding mask to an additive attention mask
        if attention_mask is not None:
            attention_mask = attention_mask[:, None, None, :]
            attention_mask = (1.0 - attention_mask) * -10000.0

        # Run through the TokenFormer layers
        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask)

        # Final layer normalization
        hidden_states = self.final_layer_norm(hidden_states)

        # Output projection
        logits = self.lm_head(hidden_states)

        return logits

    def scale_model(self, new_param_token_num):
        """
        Scale up the model parameters.
        Args:
            new_param_token_num: new number of parameter tokens
        """
        for layer in self.layers:
            # Expand the parameter tokens of the self-attention layer
            self._expand_pattention(layer.self_attention.query_pattention, new_param_token_num)
            self._expand_pattention(layer.self_attention.key_pattention, new_param_token_num)
            self._expand_pattention(layer.self_attention.value_pattention, new_param_token_num)
            self._expand_pattention(layer.self_attention.output_pattention, new_param_token_num)

            # Expand the parameter tokens of the FFN layer
            self._expand_pattention(layer.mlp.ffn_pattention, new_param_token_num)

        self.param_token_num = new_param_token_num

    def _expand_pattention(self, pattention_layer, new_param_token_num):
        """Expand the parameter tokens of a Pattention layer."""
        old_num = pattention_layer.param_token_num
        if new_param_token_num <= old_num:
            return

        # Number of parameter tokens to add
        add_num = new_param_token_num - old_num

        # Expand the key parameter tokens
        old_key = pattention_layer.key_param_tokens.data
        new_key = torch.randn(add_num, pattention_layer.param_key_dim,
                             device=old_key.device, dtype=old_key.dtype)
        nn.init.xavier_normal_(new_key)
        expanded_key = torch.cat([old_key, new_key], dim=0)
        pattention_layer.key_param_tokens = nn.Parameter(expanded_key)

        # Expand the value parameter tokens
        old_value = pattention_layer.value_param_tokens.data
        new_value = torch.randn(add_num, pattention_layer.param_value_dim,
                               device=old_value.device, dtype=old_value.dtype)
        nn.init.xavier_normal_(new_value)
        expanded_value = torch.cat([old_value, new_value], dim=0)
        pattention_layer.value_param_tokens = nn.Parameter(expanded_value)

        # Update the parameter-token count
        pattention_layer.param_token_num = new_param_token_num

# Test code
if __name__ == '__main__':
    # Create test data
    batch_size = 2
    seq_len = 128
    vocab_size = 1000

    # Create a TokenFormer model
    model = TokenFormer(
        vocab_size=vocab_size,
        hidden_size=384,
        num_hidden_layers=6,
        num_attention_heads=6,
        param_token_num=384
    )

    # Create input data
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
    attention_mask = torch.ones(batch_size, seq_len)

    # Forward pass
    logits = model(input_ids, attention_mask)

    # Test model scaling
    print("Parameter count before scaling:", sum(p.numel() for p in model.parameters()) / 1e6, 'M')
    model.scale_model(new_param_token_num=512)
    print("Parameter count after scaling:", sum(p.numel() for p in model.parameters()) / 1e6, 'M')

    # Forward pass after scaling
    logits_scaled = model(input_ids, attention_mask)

    # Print results
    print('Input size:', input_ids.size())
    print('Output logits size:', logits.size())
    print('Logits size after scaling:', logits_scaled.size())
