Still stressing over video? Learn to make cinematic clips with AI in 3 minutes — this open-source tool turns your copy into video in seconds!

The CogVideoX Module

Paper: CogVideoX: Text-to-Video Diffusion Models with an Expert Transformer
Paper/code: https://github.com/THUDM/CogVideo
Venue: ICLR
QQ discussion group: 1051849847
Same handle on every platform: 【大嘴带你水论文】 — detailed video walkthroughs are posted to Bilibili on a regular schedule
Full code at the end of the article

1. Purpose

CogVideoX is a large-scale text-to-video generation model built on a diffusion-Transformer architecture. It generates 10-second continuous videos at 16 fps and a resolution of 768×1360 pixels, closely aligned with the text prompt. The model targets the main weaknesses of earlier video generators: limited motion, short duration, and difficulty producing coherent narrative video from text. CogVideoX compresses video along both spatial and temporal dimensions with a 3D variational autoencoder (VAE), improving the compression rate and video fidelity; it uses an expert Transformer with expert adaptive LayerNorm to deepen text-video fusion; and it applies progressive training and multi-resolution frame packing, which makes it strong at generating coherent, long-duration videos with diverse shapes and dynamic motion. Experiments show that CogVideoX-5B achieves state-of-the-art performance in both automated benchmarks and human evaluation, surpassing well-known video models across multiple dimensions.
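To get a feel for why spatiotemporal compression matters, here is a quick back-of-envelope sketch of the token savings from the 4×(8×8) VAE compression described in the mechanism section (the frame count simply assumes 10 s at 16 fps; patch embedding is ignored here):

```python
frames, height, width = 160, 768, 1360   # 10 s at 16 fps, 768x1360 pixels
ct, ch, cw = 4, 8, 8                     # VAE compression per axis (time, H, W)

# Latent grid after compression, and the overall reduction in positions
latent = (frames // ct, height // ch, width // cw)
pixels = frames * height * width
latents = latent[0] * latent[1] * latent[2]
print(latent, pixels // latents)  # (40, 96, 170), 256x fewer positions
```

The 256× reduction (4·8·8) is what makes full attention over the whole video tractable for the Transformer backbone.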

Figure 1. Overall architecture of CogVideoX

2. Mechanism

1. 3D Causal Variational Autoencoder (3D Causal VAE)

The 3D VAE uses temporally causal convolutions to compress video jointly in space and time, achieving an 8×8×4 compression ratio (8× along each spatial axis, 4× along time). The encoder and decoder consist of symmetrically arranged stages built from stacked ResNet blocks, alternating 2× downsampling and upsampling. Temporally causal convolution places all temporal padding at the start of the convolution window, so future frames cannot influence the prediction for the current or past frames. To manage GPU memory for long videos, context parallelism is applied to the 3D convolutions along the temporal dimension, distributing the computation across multiple devices.

Figure 2. Structure of the 3D VAE and the context-parallel implementation of temporally causal convolution
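The causal-padding idea can be sketched in a few lines. `TinyCausalConv3d` below is a hypothetical toy module for this demo: it pads only the past side of the time axis, so perturbing the last frame leaves all earlier outputs unchanged.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyCausalConv3d(nn.Module):
    """Temporally causal 3D conv: all temporal padding goes in front,
    so the output at frame t never sees frames later than t."""
    def __init__(self, channels, kernel_t=3):
        super().__init__()
        self.time_pad = kernel_t - 1
        # Spatial padding is symmetric; temporal padding is applied manually
        self.conv = nn.Conv3d(channels, channels, (kernel_t, 3, 3),
                              padding=(0, 1, 1))

    def forward(self, x):                             # x: (B, C, T, H, W)
        x = F.pad(x, (0, 0, 0, 0, self.time_pad, 0))  # pad the past side only
        return self.conv(x)

conv = TinyCausalConv3d(channels=4)
x = torch.randn(1, 4, 8, 16, 16)
y = conv(x)
print(y.shape)  # (1, 4, 8, 16, 16): temporal length preserved
```

Causality is easy to check: changing only the last frame of the input cannot change any output before the last frame, which is exactly the property that makes streaming/autoregressive decoding of the latent possible.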

2. Expert Transformer Architecture

The expert Transformer uses expert adaptive LayerNorm to handle the text and vision modalities independently. A Vision Expert AdaLN and a Text Expert AdaLN apply diffusion-timestep (t) modulation to the visual and textual hidden states respectively, aligning the feature spaces of the two modalities while adding minimal extra parameters. Positions are encoded with 3D-RoPE, which extends the original RoPE to three dimensions: each latent in the video tensor is addressed by a 3D coordinate (x, y, t), and 1D-RoPE is applied independently to each coordinate axis.
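A minimal sketch of the expert-AdaLN idea (illustrative sizes, not the paper's): each modality gets its own small MLP "expert" that maps the timestep embedding to shift/scale/gate parameters, applied after LayerNorm to that modality's tokens.

```python
import torch
import torch.nn as nn

hidden, t_dim = 8, 16

# One modulation expert per modality: a SiLU-MLP mapping the timestep
# embedding to (shift, scale, gate) for that modality's token stream.
vision_mod = nn.Sequential(nn.SiLU(), nn.Linear(t_dim, 3 * hidden))
text_mod = nn.Sequential(nn.SiLU(), nn.Linear(t_dim, 3 * hidden))
norm = nn.LayerNorm(hidden)

def modulate(x, shift, scale):
    # Broadcast the per-sample shift/scale over the token dimension
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

t_emb = torch.randn(2, t_dim)               # diffusion timestep embedding
text_tokens = torch.randn(2, 5, hidden)
vision_tokens = torch.randn(2, 20, hidden)

t_shift, t_scale, t_gate = text_mod(t_emb).chunk(3, dim=1)
v_shift, v_scale, v_gate = vision_mod(t_emb).chunk(3, dim=1)

# Gates (t_gate, v_gate) would scale the residual branches; omitted here
text_out = modulate(norm(text_tokens), t_shift, t_scale)
vision_out = modulate(norm(vision_tokens), v_shift, v_scale)
print(text_out.shape, vision_out.shape)
```

Because only the small modulation MLPs are duplicated per modality, the two "experts" align text and vision feature spaces at a very low parameter cost.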

Figure 3. The 3D full attention mechanism

3. 3D Full Attention

Unlike prior work that uses separate spatial and temporal attention, CogVideoX proposes a hybrid 3D text-video attention. Separated attention requires a large amount of visual information to be passed implicitly between the two attention stages, which significantly increases learning difficulty and makes it hard to keep large moving objects consistent. 3D full attention not only produces better results but also adapts easily to various parallel-acceleration methods, drawing on the success of long-context training in LLMs and the efficiency of FlashAttention.
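The contrast with separated attention can be shown directly: flatten the whole (T, H, W) latent grid into one token sequence, prepend the text tokens, and run a single attention call over everything. This sketch uses PyTorch's `scaled_dot_product_attention` (the FlashAttention-style fused kernel); the grid sizes are made up for the demo.

```python
import torch
import torch.nn.functional as F

B, heads, dim = 1, 4, 16
T, H, W = 4, 8, 8        # toy latent grid
text_len = 7

# 3D full attention: every text token and every spatiotemporal video
# token attends to every other token in one joint sequence.
video = torch.randn(B, heads, T * H * W, dim)
text = torch.randn(B, heads, text_len, dim)
q = k = v = torch.cat([text, video], dim=2)

out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # (1, 4, 263, 16): 263 = 7 text + 4*8*8 video tokens
```

With separated attention, a token at (t, x, y) could reach (t+1, x', y') only through two hops (temporal then spatial); here every pair interacts in one step, which is what keeps large moving objects coherent.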

Figure 4. Training scheme

3. Code

Full code on GitCode: https://gitcode.com/2301_80107842/research

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from einops import rearrange, repeat
from typing import Optional, Tuple, Dict, Any

def modulate(x, shift, scale):
    """AdaLN modulation: per-sample scale and shift, broadcast over tokens."""
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

def rotate_half(x):
    """Rotate interleaved feature pairs (x1, x2) -> (-x2, x1) for RoPE."""
    x = rearrange(x, "... (d r) -> ... d r", r=2)
    x1, x2 = x.unbind(dim=-1)
    x = torch.stack((-x2, x1), dim=-1)
    return rearrange(x, "... d r -> ... (d r)")

class ExpertAdaptiveLayerNorm(nn.Module):
    """Expert adaptive LayerNorm: separate modulation MLPs for text and vision."""
    def __init__(self, hidden_size, time_embed_dim, num_layers):
        super().__init__()
        self.hidden_size = hidden_size

        # One AdaLN modulation module per layer
        self.vision_adaLN_modulations = nn.ModuleList([
            nn.Sequential(
                nn.SiLU(),
                nn.Linear(time_embed_dim, 6 * hidden_size, bias=True)
            ) for _ in range(num_layers)
        ])

        self.text_adaLN_modulations = nn.ModuleList([
            nn.Sequential(
                nn.SiLU(),
                nn.Linear(time_embed_dim, 6 * hidden_size, bias=True)
            ) for _ in range(num_layers)
        ])

    def forward(self, hidden_states, emb, text_length, layer_id):
        """
        Args:
            hidden_states: input hidden states (B, seq_len, hidden_size)
            emb: timestep embedding (B, time_embed_dim)
            text_length: length of the text prefix
            layer_id: index of the current layer
        Returns:
            dict with modulation parameters and the split hidden states
        """
        # Split the sequence into its text and vision parts
        text_hidden_states = hidden_states[:, :text_length]  # (B, text_len, hidden_size)
        img_hidden_states = hidden_states[:, text_length:]   # (B, img_len, hidden_size)

        # Vision modulation parameters
        vision_modulation = self.vision_adaLN_modulations[layer_id](emb)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = vision_modulation.chunk(6, dim=1)

        # Text modulation parameters
        text_modulation = self.text_adaLN_modulations[layer_id](emb)
        text_shift_msa, text_scale_msa, text_gate_msa, text_shift_mlp, text_scale_mlp, text_gate_mlp = text_modulation.chunk(6, dim=1)

        return {
            'vision_params': (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp),
            'text_params': (text_shift_msa, text_scale_msa, text_gate_msa, text_shift_mlp, text_scale_mlp, text_gate_mlp),
            'text_hidden_states': text_hidden_states,
            'img_hidden_states': img_hidden_states
        }

class Rotary3DPositionEmbedding(nn.Module):
    """3D rotary position embedding (3D-RoPE)."""
    def __init__(self, hidden_size_head, height, width, compressed_num_frames, theta=10000):
        super().__init__()

        # Allocate the head dimension: ~1/4 to time, ~3/8 each to height and
        # width. Each slice is forced to be even so the interleaved RoPE
        # pairs line up (e.g. head_dim=72 would otherwise give an odd
        # dim_h=27 and a 70-dim table that cannot broadcast against q/k).
        dim_t = (hidden_size_head // 4) // 2 * 2
        dim_h = (hidden_size_head // 8 * 3) // 2 * 2
        dim_w = hidden_size_head - dim_t - dim_h

        # Per-axis inverse frequencies
        freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t))
        freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h))
        freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w))

        # Position grids for each axis
        grid_t = torch.arange(compressed_num_frames, dtype=torch.float32)
        grid_h = torch.arange(height, dtype=torch.float32)
        grid_w = torch.arange(width, dtype=torch.float32)

        # Outer product: position x frequency per axis
        freqs_t = torch.einsum("..., f -> ... f", grid_t, freqs_t)
        freqs_h = torch.einsum("..., f -> ... f", grid_h, freqs_h)
        freqs_w = torch.einsum("..., f -> ... f", grid_w, freqs_w)

        freqs_t = repeat(freqs_t, "... n -> ... (n r)", r=2)
        freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2)
        freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)

        # Broadcast each axis embedding over the full (T, H, W) grid and concatenate
        freqs = torch.cat([
            freqs_t[:, None, None, :].expand(-1, height, width, -1),
            freqs_h[None, :, None, :].expand(compressed_num_frames, -1, width, -1),
            freqs_w[None, None, :, :].expand(compressed_num_frames, height, -1, -1)
        ], dim=-1)

        freqs = freqs.contiguous()
        self.register_buffer('freqs_sin', freqs.sin())
        self.register_buffer('freqs_cos', freqs.cos())

    def forward(self, q, k, seq_len):
        """
        Apply 3D RoPE to queries and keys.
        Args:
            q: queries (B, num_heads, seq_len, head_dim)
            k: keys (B, num_heads, seq_len, head_dim)
            seq_len: number of vision tokens
        Returns:
            q, k with the rotary embedding applied
        """
        # Slice the precomputed tables to the current sequence length
        cos = self.freqs_cos.flatten(0, 2)[:seq_len]  # (seq_len, head_dim)
        sin = self.freqs_sin.flatten(0, 2)[:seq_len]  # (seq_len, head_dim)

        # Apply the rotation
        q_rot = q * cos + rotate_half(q) * sin
        k_rot = k * cos + rotate_half(k) * sin

        return q_rot, k_rot

class CausalConv3d(nn.Module):
    """3D convolution that is causal along the temporal axis."""
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride

        # Normalize kernel_size and padding to 3-tuples
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size, kernel_size)
        if isinstance(padding, int):
            padding = (padding, padding, padding)

        # Causal along time: all temporal padding goes in front
        self.time_pad = kernel_size[0] - 1
        self.spatial_pad = (padding[1], padding[2])

        self.conv = nn.Conv3d(
            in_channels, out_channels, kernel_size,
            stride=stride, padding=(0, padding[1], padding[2])
        )

    def forward(self, x):
        """
        Args:
            x: input tensor (B, C, T, H, W)
        Returns:
            output tensor (B, C', T', H', W')
        """
        # Causal temporal padding: pad the past side only
        if self.time_pad > 0:
            x = F.pad(x, (0, 0, 0, 0, self.time_pad, 0))

        return self.conv(x)

class VAE3D(nn.Module):
    """3D variational autoencoder. Note: this simplified demo only reaches
    2x temporal and 2x spatial compression; the paper's full VAE stacks
    more stages to reach the 4x8x8 ratio."""
    def __init__(self, in_channels=3, out_channels=3, hidden_channels=128,
                 latent_channels=16, compression_ratio=(4, 8, 8)):
        super().__init__()
        self.compression_ratio = compression_ratio

        # Encoder
        self.encoder = nn.Sequential(
            CausalConv3d(in_channels, hidden_channels, 3, padding=1),
            nn.GroupNorm(8, hidden_channels),
            nn.SiLU(),
            CausalConv3d(hidden_channels, hidden_channels, (1, 4, 4), stride=(1, 2, 2), padding=(0, 1, 1)),
            nn.GroupNorm(8, hidden_channels),
            nn.SiLU(),
            CausalConv3d(hidden_channels, hidden_channels, (2, 3, 3), stride=(2, 1, 1), padding=(0, 1, 1)),
            nn.GroupNorm(8, hidden_channels),
            nn.SiLU(),
            CausalConv3d(hidden_channels, latent_channels * 2, 3, padding=1),  # mean and log-variance
        )

        # Decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose3d(latent_channels, hidden_channels, 3, padding=1),
            nn.GroupNorm(8, hidden_channels),
            nn.SiLU(),
            nn.ConvTranspose3d(hidden_channels, hidden_channels, (2, 3, 3), stride=(2, 1, 1), padding=(0, 1, 1)),
            nn.GroupNorm(8, hidden_channels),
            nn.SiLU(),
            nn.ConvTranspose3d(hidden_channels, hidden_channels, (1, 4, 4), stride=(1, 2, 2), padding=(0, 1, 1)),
            nn.GroupNorm(8, hidden_channels),
            nn.SiLU(),
            nn.ConvTranspose3d(hidden_channels, out_channels, 3, padding=1),
        )

    def encode(self, x):
        """Encode video into posterior mean and log-variance."""
        h = self.encoder(x)
        mean, logvar = h.chunk(2, dim=1)
        return mean, logvar

    def decode(self, z):
        """Decode latents back to pixel space."""
        return self.decoder(z)

    def reparameterize(self, mean, logvar):
        """Reparameterization trick: z = mean + std * eps."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mean + eps * std

    def forward(self, x):
        """Encode, sample, decode."""
        mean, logvar = self.encode(x)
        z = self.reparameterize(mean, logvar)
        recon = self.decode(z)
        return recon, mean, logvar

class CogVideoXBlock(nn.Module):
    """A CogVideoX Transformer block."""
    def __init__(self, hidden_size, num_heads, time_embed_dim, mlp_ratio=4.0):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads

        # Layer norms
        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)

        # Multi-head attention projections
        self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=True)
        self.proj = nn.Linear(hidden_size, hidden_size)

        # MLP
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim),
            nn.GELU(),
            nn.Linear(mlp_hidden_dim, hidden_size)
        )

        # Expert AdaLN (one modulation set per block)
        self.expert_adaln = ExpertAdaptiveLayerNorm(hidden_size, time_embed_dim, 1)

        # 3D RoPE over an example latent grid (T=16, H=32, W=32)
        self.rope_3d = Rotary3DPositionEmbedding(head_dim, 32, 32, 16)

    def forward(self, x, emb, text_length, layer_id=0):
        """
        Args:
            x: input tensor (B, seq_len, hidden_size)
            emb: timestep embedding (B, time_embed_dim)
            text_length: length of the text prefix
            layer_id: global layer index (not used to index the modulation
                modules, since each block owns exactly one AdaLN set)
        Returns:
            output tensor (B, seq_len, hidden_size)
        """
        # Expert AdaLN modulation. This block's AdaLN was built with
        # num_layers=1, so always index entry 0; indexing with the global
        # layer_id would raise an IndexError for every layer after the first.
        adaln_output = self.expert_adaln(x, emb, text_length, 0)
        vision_params = adaln_output['vision_params']
        text_params = adaln_output['text_params']
        text_hidden = adaln_output['text_hidden_states']
        img_hidden = adaln_output['img_hidden_states']

        # Attention branch
        shift_msa, scale_msa, gate_msa = vision_params[:3]
        text_shift_msa, text_scale_msa, text_gate_msa = text_params[:3]

        # Apply AdaLN to the attention inputs
        img_attn_input = modulate(self.norm1(img_hidden), shift_msa, scale_msa)
        text_attn_input = modulate(self.norm1(text_hidden), text_shift_msa, text_scale_msa)
        attn_input = torch.cat([text_attn_input, img_attn_input], dim=1)

        # Multi-head attention
        B, seq_len, _ = attn_input.shape
        qkv = self.qkv(attn_input).reshape(B, seq_len, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Apply 3D RoPE to the vision tokens only
        if seq_len > text_length:
            q_text, q_img = q[:, :, :text_length], q[:, :, text_length:]
            k_text, k_img = k[:, :, :text_length], k[:, :, text_length:]
            q_img, k_img = self.rope_3d(q_img, k_img, seq_len - text_length)
            q = torch.cat([q_text, q_img], dim=2)
            k = torch.cat([k_text, k_img], dim=2)

        # Scaled dot-product attention
        scale = (self.hidden_size // self.num_heads) ** -0.5
        attn = (q @ k.transpose(-2, -1)) * scale
        attn = F.softmax(attn, dim=-1)
        out = attn @ v

        out = out.transpose(1, 2).reshape(B, seq_len, self.hidden_size)
        out = self.proj(out)

        # Split the output and apply the gates
        text_attn_out = out[:, :text_length]
        img_attn_out = out[:, text_length:]

        text_hidden = text_hidden + text_gate_msa.unsqueeze(1) * text_attn_out
        img_hidden = img_hidden + gate_msa.unsqueeze(1) * img_attn_out

        # MLP branch
        shift_mlp, scale_mlp, gate_mlp = vision_params[3:]
        text_shift_mlp, text_scale_mlp, text_gate_mlp = text_params[3:]

        img_mlp_input = modulate(self.norm2(img_hidden), shift_mlp, scale_mlp)
        text_mlp_input = modulate(self.norm2(text_hidden), text_shift_mlp, text_scale_mlp)
        mlp_input = torch.cat([text_mlp_input, img_mlp_input], dim=1)

        mlp_out = self.mlp(mlp_input)
        text_mlp_out = mlp_out[:, :text_length]
        img_mlp_out = mlp_out[:, text_length:]

        text_hidden = text_hidden + text_gate_mlp.unsqueeze(1) * text_mlp_out
        img_hidden = img_hidden + gate_mlp.unsqueeze(1) * img_mlp_out

        return torch.cat([text_hidden, img_hidden], dim=1)

class CogVideoX(nn.Module):
    """CogVideoX backbone: an expert Transformer over text + video tokens."""
    def __init__(self, hidden_size=1152, num_layers=28, num_heads=16,
                 text_hidden_size=4096, time_embed_dim=512):
        super().__init__()
        self.hidden_size = hidden_size
        self.time_embed_dim = time_embed_dim

        # Timestep embedding
        self.time_embed = nn.Sequential(
            nn.Linear(hidden_size, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim)
        )

        # Text projection
        self.text_proj = nn.Linear(text_hidden_size, hidden_size)

        # Transformer layers
        self.layers = nn.ModuleList([
            CogVideoXBlock(hidden_size, num_heads, time_embed_dim)
            for _ in range(num_layers)
        ])

        # Output head
        self.norm_final = nn.LayerNorm(hidden_size)
        self.proj_out = nn.Linear(hidden_size, hidden_size)

    def forward(self, video_latents, text_embeds, timesteps):
        """
        Args:
            video_latents: video latent tokens (B, seq_len, hidden_size)
            text_embeds: text embeddings (B, text_len, text_hidden_size)
            timesteps: diffusion timesteps (B,)
        Returns:
            output latents (B, seq_len, hidden_size)
        """
        # Timestep embedding
        t_emb = self.time_embed(timestep_embedding(timesteps, self.hidden_size))

        # Project text into the model width
        text_embeds = self.text_proj(text_embeds)
        text_length = text_embeds.shape[1]

        # Concatenate text and video into one sequence
        x = torch.cat([text_embeds, video_latents], dim=1)

        # Run the Transformer layers
        for i, layer in enumerate(self.layers):
            x = layer(x, t_emb, text_length, i)

        # Output head (video tokens only)
        video_out = x[:, text_length:]
        video_out = self.norm_final(video_out)
        video_out = self.proj_out(video_out)

        return video_out

def timestep_embedding(timesteps, dim, max_period=10000):
    """Sinusoidal timestep embedding."""
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    ).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding

# Smoke test
if __name__ == '__main__':
    # Test configuration
    batch_size = 2
    seq_len = 1024  # number of video tokens
    text_len = 77   # number of text tokens
    hidden_size = 1152
    text_hidden_size = 4096

    # Build the model
    model = CogVideoX(hidden_size=hidden_size, num_layers=28, num_heads=16,
                      text_hidden_size=text_hidden_size, time_embed_dim=512)

    # Random inputs
    video_latents = torch.randn(batch_size, seq_len, hidden_size)
    text_embeds = torch.randn(batch_size, text_len, text_hidden_size)
    timesteps = torch.randint(0, 1000, (batch_size,))

    # Forward pass
    output = model(video_latents, text_embeds, timesteps)

    # Test the 3D VAE
    vae = VAE3D(in_channels=3, latent_channels=16)
    video_input = torch.randn(batch_size, 3, 16, 64, 64)  # (B, C, T, H, W)
    recon, mean, logvar = vae(video_input)

    # Report shapes and parameter count
    print('video latents:', video_latents.size())
    print('text embeddings:', text_embeds.size())
    print('output:', output.size())
    print('parameters:', sum(p.numel() for p in model.parameters()) / 1e6, 'M')
    print('3D VAE input:', video_input.size())
    print('3D VAE reconstruction:', recon.size())
