CogVideoX模块
论文《CogVideoX: Text-to-Video Diffusion Models with an Expert Transformer》
论文地址: https://arxiv.org/abs/2408.06072
代码地址: https://github.com/THUDM/CogVideo
发表会议: ICLR
深度交流Q裙:1051849847
全网同名 【大嘴带你水论文】 B站定时发布详细讲解视频
详细代码见文章最后
1、作用
CogVideoX是一个大规模文本到视频生成模块,基于扩散Transformer架构,能够生成10秒连续视频,帧率16fps,分辨率768×1360像素,与文本提示完美对齐。该模块专门解决以往视频生成模型运动有限、时长较短、难以基于文本生成连贯叙事视频的问题。CogVideoX通过3D变分自编码器(VAE)在时空维度压缩视频,提升压缩率和视频保真度;采用专家Transformer和专家自适应LayerNorm促进文本-视频深度融合;运用渐进训练和多分辨率帧打包技术,擅长生成连贯、长时长、多样形状和动态运动的视频。实验表明,CogVideoX-5B在自动化基准测试和人工评估中均达到SOTA性能,在多个维度上超越知名视频模型。
图1. CogVideoX的整体架构
2、机制
1、3D因果变分自编码器(3D Causal VAE):
3D VAE采用时间因果卷积设计,在时空维度同时压缩视频数据,实现8×8×4的压缩比。编码器和解码器由对称排列的阶段组成,通过ResNet块堆叠阶段交替执行2×下采样和上采样。时间因果卷积将所有填充放置在卷积空间的开始位置,确保未来信息不会影响当前或过去的预测。为处理长时长视频的GPU内存使用问题,在时间维度应用上下文并行进行3D卷积,将计算分布到多个设备上。
图2. 3D VAE的结构设计和时间因果卷积的上下文并行实现
2、专家Transformer架构(Expert Transformer):
专家Transformer包含专家自适应LayerNorm来独立处理文本和视觉两种模态。Vision Expert AdaLN和Text Expert AdaLN分别对视觉隐藏状态和文本隐藏状态应用扩散过程的时间步t调制,促进两种模态特征空间的对齐同时最小化额外参数。采用3D-RoPE位置编码,将原始RoPE扩展到3D,每个视频张量中的潜在表示用3D坐标(x,y,t)表示,独立对坐标的每个维度应用1D-RoPE。
图3. 3D全注意力机制
3、3D全注意力机制(3D Full Attention):
不同于以往工作采用分离的空间和时间注意力,CogVideoX提出3D文本-视频混合注意力机制。分离注意力方法需要大量隐式视觉信息传输,显著增加学习复杂性,难以维持大运动物体的一致性。3D全注意力机制不仅获得更好结果,还能轻松适应各种并行加速方法,借鉴长上下文训练在LLM中的成功和FlashAttention的效率。
图4. 训练机制
3、代码
完整代码见gitcode地址:https://gitcode.com/2301_80107842/research
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from einops import rearrange, repeat
from typing import Optional, Tuple, Dict, Any
def modulate(x, shift, scale):
    """Apply AdaLN modulation to a token sequence.

    `shift` and `scale` of shape (B, D) are broadcast over the sequence
    dimension of `x` (B, L, D): out = x * (1 + scale) + shift.
    """
    scaled = x * (scale.unsqueeze(1) + 1)
    return scaled + shift.unsqueeze(1)
def rotate_half(x):
    """RoPE pair rotation: map each consecutive channel pair (x1, x2) of the
    last dimension to (-x2, x1).

    Rewritten with native torch reshapes (`unflatten`/`flatten`) instead of
    the original einops `rearrange` calls — same result, one less third-party
    dependency for a trivial reshape.
    """
    pairs = x.unflatten(-1, (-1, 2))           # (..., d/2, 2)
    x1, x2 = pairs.unbind(dim=-1)              # even / odd channels
    rotated = torch.stack((-x2, x1), dim=-1)   # swap & negate within pairs
    return rotated.flatten(-2)                 # back to (..., d)
class ExpertAdaptiveLayerNorm(nn.Module):
    """Expert adaptive LayerNorm.

    Produces separate timestep-conditioned modulation parameters (shift /
    scale / gate, for both the attention and MLP branches) for the text and
    the vision token streams, one modulation head per transformer layer.
    """

    def __init__(self, hidden_size, time_embed_dim, num_layers):
        super().__init__()
        self.hidden_size = hidden_size

        def build_expert_stack():
            # One SiLU -> Linear head per layer; the linear projects the time
            # embedding to 6 modulation vectors (shift/scale/gate for MSA and
            # shift/scale/gate for the MLP).
            return nn.ModuleList(
                nn.Sequential(
                    nn.SiLU(),
                    nn.Linear(time_embed_dim, 6 * hidden_size, bias=True),
                )
                for _ in range(num_layers)
            )

        # Vision experts are built before text experts (keeps parameter
        # initialization order identical to the reference implementation).
        self.vision_adaLN_modulations = build_expert_stack()
        self.text_adaLN_modulations = build_expert_stack()

    def forward(self, hidden_states, emb, text_length, layer_id):
        """Split the joint sequence and compute per-modality modulation.

        Args:
            hidden_states: (B, seq_len, hidden_size) joint text+vision
                sequence, text tokens first.
            emb: (B, time_embed_dim) diffusion timestep embedding.
            text_length: number of leading text tokens.
            layer_id: index of the per-layer modulation head to use.

        Returns:
            dict with 'vision_params' / 'text_params' (each a 6-tuple:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp)
            plus the split 'text_hidden_states' / 'img_hidden_states'.
        """
        text_part = hidden_states[:, :text_length]    # (B, text_len, D)
        vision_part = hidden_states[:, text_length:]  # (B, img_len, D)

        def split6(head):
            # Project the timestep embedding and split into the 6 vectors.
            return head(emb).chunk(6, dim=1)

        vision_six = split6(self.vision_adaLN_modulations[layer_id])
        text_six = split6(self.text_adaLN_modulations[layer_id])

        return {
            'vision_params': vision_six,
            'text_params': text_six,
            'text_hidden_states': text_part,
            'img_hidden_states': vision_part,
        }
class Rotary3DPositionEmbedding(nn.Module):
    """3D rotary position embedding (3D-RoPE).

    The per-head channel budget is split between the temporal axis (1/4 of
    the head dim) and the two spatial axes (3/8 each); sin/cos tables are
    precomputed for every (t, h, w) latent position.

    Improvement: the original used einops `repeat(..., "... n -> ... (n r)",
    r=2)` just to duplicate each frequency for its channel pair — replaced
    with the exactly-equivalent native `Tensor.repeat_interleave(2, dim=-1)`,
    dropping the einops dependency from this class.
    """

    def __init__(self, hidden_size_head, height, width, compressed_num_frames, theta=10000):
        super().__init__()
        # Channel split: time 1/4, height 3/8, width 3/8 of the head dim.
        # NOTE(review): assumes hidden_size_head is divisible by 8 so the
        # three parts sum back to hidden_size_head — confirm at call sites.
        dim_t = hidden_size_head // 4
        dim_h = hidden_size_head // 8 * 3
        dim_w = hidden_size_head // 8 * 3

        def axis_table(dim, positions):
            # Standard RoPE inverse-frequency ladder for one axis, evaluated
            # at every integer position; each frequency is then duplicated
            # consecutively so it lines up with rotate_half's channel pairs.
            inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
            angles = torch.einsum("..., f -> ... f", positions, inv_freq)
            return angles.repeat_interleave(2, dim=-1)

        freqs_t = axis_table(dim_t, torch.arange(compressed_num_frames, dtype=torch.float32))
        freqs_h = axis_table(dim_h, torch.arange(height, dtype=torch.float32))
        freqs_w = axis_table(dim_w, torch.arange(width, dtype=torch.float32))

        # Broadcast each axis table over the full (T, H, W) grid and
        # concatenate along channels -> (T, H, W, head_dim).
        freqs = torch.cat(
            [
                freqs_t[:, None, None, :].expand(-1, height, width, -1),
                freqs_h[None, :, None, :].expand(compressed_num_frames, -1, width, -1),
                freqs_w[None, None, :, :].expand(compressed_num_frames, height, -1, -1),
            ],
            dim=-1,
        ).contiguous()

        # Buffers (not parameters): move with .to(device), never trained.
        self.register_buffer('freqs_sin', freqs.sin())
        self.register_buffer('freqs_cos', freqs.cos())

    def forward(self, q, k, seq_len):
        """Apply the 3D rotary embedding to query/key tensors.

        Args:
            q: query tensor (B, num_heads, seq_len, head_dim).
            k: key tensor (B, num_heads, seq_len, head_dim).
            seq_len: number of leading grid positions to use (the (T, H, W)
                grid is flattened in row-major order and truncated).

        Returns:
            (q_rot, k_rot) with the rotation applied.
        """
        # Flatten (T, H, W, D) -> (T*H*W, D) and take the first seq_len rows;
        # these broadcast over the batch and head dimensions of q/k.
        cos = self.freqs_cos.flatten(0, 2)[:seq_len]
        sin = self.freqs_sin.flatten(0, 2)[:seq_len]
        q_rot = q * cos + rotate_half(q) * sin
        k_rot = k * cos + rotate_half(k) * sin
        return q_rot, k_rot
class CausalConv3d(nn.Module):
    """3D convolution that is causal along the time axis.

    All temporal padding is placed *before* the first frame, so an output
    frame never depends on later input frames. Spatial padding behaves like
    a regular Conv3d; any user-supplied temporal padding value is
    intentionally ignored in favor of the causal amount (kernel_t - 1).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        # Normalize scalar arguments to (t, h, w) triples.
        k = (kernel_size,) * 3 if isinstance(kernel_size, int) else tuple(kernel_size)
        p = (padding,) * 3 if isinstance(padding, int) else tuple(padding)
        # Causal amount: pad kernel_t - 1 frames in front only.
        self.time_pad = k[0] - 1
        self.spatial_pad = (p[1], p[2])
        # The conv itself does zero temporal padding; time is handled in forward.
        self.conv = nn.Conv3d(
            in_channels, out_channels, k,
            stride=stride, padding=(0, p[1], p[2]),
        )

    def forward(self, x):
        """Causally convolve a video tensor.

        Args:
            x: input of shape (B, C, T, H, W).

        Returns:
            Convolved tensor (B, C', T', H', W').
        """
        front = self.time_pad
        if front > 0:
            # F.pad pads last dim first: (W_l, W_r, H_l, H_r, T_front, T_back).
            x = F.pad(x, (0, 0, 0, 0, front, 0))
        return self.conv(x)
class VAE3D(nn.Module):
    """3D variational autoencoder for video.

    Encodes a (B, C, T, H, W) clip into a diagonal-Gaussian latent
    (mean, logvar) and decodes a sampled latent back to pixel space.
    """

    def __init__(self, in_channels=3, out_channels=3, hidden_channels=128,
                 latent_channels=16, compression_ratio=(4, 8, 8)):
        super().__init__()
        # NOTE(review): compression_ratio is stored but never used; the
        # stages below realize a 2x temporal / 2x2 spatial compression, not
        # (4, 8, 8) — confirm whether additional stages were intended.
        self.compression_ratio = compression_ratio

        def norm_act(ch):
            # GroupNorm + SiLU pair applied after each conv stage.
            return [nn.GroupNorm(8, ch), nn.SiLU()]

        # Encoder: stem -> 2x spatial downsample -> 2x temporal downsample
        # -> projection to mean & log-variance channels.
        enc = [CausalConv3d(in_channels, hidden_channels, 3, padding=1)]
        enc += norm_act(hidden_channels)
        enc += [CausalConv3d(hidden_channels, hidden_channels, (1, 4, 4),
                             stride=(1, 2, 2), padding=(0, 1, 1))]
        enc += norm_act(hidden_channels)
        enc += [CausalConv3d(hidden_channels, hidden_channels, (2, 3, 3),
                             stride=(2, 1, 1), padding=(0, 1, 1))]
        enc += norm_act(hidden_channels)
        enc += [CausalConv3d(hidden_channels, latent_channels * 2, 3, padding=1)]
        self.encoder = nn.Sequential(*enc)

        # Decoder mirrors the encoder with transposed convolutions.
        dec = [nn.ConvTranspose3d(latent_channels, hidden_channels, 3, padding=1)]
        dec += norm_act(hidden_channels)
        dec += [nn.ConvTranspose3d(hidden_channels, hidden_channels, (2, 3, 3),
                                   stride=(2, 1, 1), padding=(0, 1, 1))]
        dec += norm_act(hidden_channels)
        dec += [nn.ConvTranspose3d(hidden_channels, hidden_channels, (1, 4, 4),
                                   stride=(1, 2, 2), padding=(0, 1, 1))]
        dec += norm_act(hidden_channels)
        dec += [nn.ConvTranspose3d(hidden_channels, out_channels, 3, padding=1)]
        self.decoder = nn.Sequential(*dec)

    def encode(self, x):
        """Encode x into (mean, logvar), each with latent_channels channels."""
        mean, logvar = self.encoder(x).chunk(2, dim=1)
        return mean, logvar

    def decode(self, z):
        """Decode a latent sample back to pixel space."""
        return self.decoder(z)

    def reparameterize(self, mean, logvar):
        """Sample z ~ N(mean, exp(logvar)) via the reparameterization trick."""
        noise = torch.randn_like(mean)
        return mean + noise * torch.exp(0.5 * logvar)

    def forward(self, x):
        """Return (reconstruction, mean, logvar) for an input clip."""
        mean, logvar = self.encode(x)
        recon = self.decode(self.reparameterize(mean, logvar))
        return recon, mean, logvar
class CogVideoXBlock(nn.Module):
    """CogVideoX transformer block.

    Applies expert-AdaLN-modulated joint (text + vision) 3D full attention,
    with 3D-RoPE applied to the vision tokens only, followed by a gated MLP
    branch. Text tokens occupy the leading positions of the sequence.
    """

    def __init__(self, hidden_size, num_heads, time_embed_dim, mlp_ratio=4.0):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        # Shared LayerNorms; modality-specific shift/scale come from AdaLN.
        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)
        # Fused QKV projection and attention output projection.
        self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=True)
        self.proj = nn.Linear(hidden_size, hidden_size)
        # Position-wise feed-forward network.
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim),
            nn.GELU(),
            nn.Linear(mlp_hidden_dim, hidden_size)
        )
        # Each block owns its own *single-layer* expert AdaLN.
        self.expert_adaln = ExpertAdaptiveLayerNorm(hidden_size, time_embed_dim, 1)
        # 3D RoPE table (example grid: 32x32 spatial, 16 compressed frames).
        self.rope_3d = Rotary3DPositionEmbedding(head_dim, 32, 32, 16)

    def forward(self, x, emb, text_length, layer_id=0):
        """Forward pass.

        Args:
            x: (B, seq_len, hidden_size) joint sequence, text tokens first.
            emb: (B, time_embed_dim) diffusion timestep embedding.
            text_length: number of leading text tokens.
            layer_id: accepted for interface compatibility but ignored — each
                block owns a single-layer ExpertAdaptiveLayerNorm, so indexing
                it with the caller's *global* layer index (as CogVideoX.forward
                does) raised IndexError for every block past the first.

        Returns:
            (B, seq_len, hidden_size) updated joint sequence.
        """
        # BUG FIX: always use modulation head 0. Previously `layer_id` was
        # forwarded into a length-1 ModuleList, crashing for layer_id >= 1.
        adaln_output = self.expert_adaln(x, emb, text_length, 0)
        vision_params = adaln_output['vision_params']
        text_params = adaln_output['text_params']
        text_hidden = adaln_output['text_hidden_states']
        img_hidden = adaln_output['img_hidden_states']

        # ---- Attention branch ----
        shift_msa, scale_msa, gate_msa = vision_params[:3]
        text_shift_msa, text_scale_msa, text_gate_msa = text_params[:3]
        # Modality-specific AdaLN modulation, then re-join text + vision.
        img_attn_input = modulate(self.norm1(img_hidden), shift_msa, scale_msa)
        text_attn_input = modulate(self.norm1(text_hidden), text_shift_msa, text_scale_msa)
        attn_input = torch.cat([text_attn_input, img_attn_input], dim=1)

        # Fused QKV -> (3, B, heads, seq, head_dim).
        B, seq_len, _ = attn_input.shape
        qkv = self.qkv(attn_input).reshape(B, seq_len, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # 3D RoPE rotates the vision q/k only; text q/k stay unrotated.
        if seq_len > text_length:
            q_text, q_img = q[:, :, :text_length], q[:, :, text_length:]
            k_text, k_img = k[:, :, :text_length], k[:, :, text_length:]
            q_img, k_img = self.rope_3d(q_img, k_img, seq_len - text_length)
            q = torch.cat([q_text, q_img], dim=2)
            k = torch.cat([k_text, k_img], dim=2)

        # Scaled dot-product attention over the full joint sequence
        # (3D full attention: no space/time factorization).
        scale = (self.hidden_size // self.num_heads) ** -0.5
        attn = (q @ k.transpose(-2, -1)) * scale
        attn = F.softmax(attn, dim=-1)
        out = attn @ v
        out = out.transpose(1, 2).reshape(B, seq_len, self.hidden_size)
        out = self.proj(out)

        # Gated residual connection, per modality.
        text_attn_out = out[:, :text_length]
        img_attn_out = out[:, text_length:]
        text_hidden = text_hidden + text_gate_msa.unsqueeze(1) * text_attn_out
        img_hidden = img_hidden + gate_msa.unsqueeze(1) * img_attn_out

        # ---- MLP branch ----
        shift_mlp, scale_mlp, gate_mlp = vision_params[3:]
        text_shift_mlp, text_scale_mlp, text_gate_mlp = text_params[3:]
        img_mlp_input = modulate(self.norm2(img_hidden), shift_mlp, scale_mlp)
        text_mlp_input = modulate(self.norm2(text_hidden), text_shift_mlp, text_scale_mlp)
        mlp_input = torch.cat([text_mlp_input, img_mlp_input], dim=1)
        mlp_out = self.mlp(mlp_input)
        text_mlp_out = mlp_out[:, :text_length]
        img_mlp_out = mlp_out[:, text_length:]
        text_hidden = text_hidden + text_gate_mlp.unsqueeze(1) * text_mlp_out
        img_hidden = img_hidden + gate_mlp.unsqueeze(1) * img_mlp_out

        return torch.cat([text_hidden, img_hidden], dim=1)
class CogVideoX(nn.Module):
    """Top-level CogVideoX diffusion transformer.

    Projects text-encoder features into the model width, prepends them to
    the video latent tokens, runs the joint sequence through a stack of
    CogVideoXBlock layers conditioned on the timestep embedding, and applies
    the output head to the video tokens only.
    """

    def __init__(self, hidden_size=1152, num_layers=28, num_heads=16,
                 text_hidden_size=4096, time_embed_dim=512):
        super().__init__()
        self.hidden_size = hidden_size
        self.time_embed_dim = time_embed_dim
        # MLP mapping sinusoidal timestep features to the conditioning dim.
        self.time_embed = nn.Sequential(
            nn.Linear(hidden_size, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim),
        )
        # Projects text-encoder features into the transformer width.
        self.text_proj = nn.Linear(text_hidden_size, hidden_size)
        # Transformer stack.
        self.layers = nn.ModuleList(
            CogVideoXBlock(hidden_size, num_heads, time_embed_dim)
            for _ in range(num_layers)
        )
        # Output head (applied to the video tokens only).
        self.norm_final = nn.LayerNorm(hidden_size)
        self.proj_out = nn.Linear(hidden_size, hidden_size)

    def forward(self, video_latents, text_embeds, timesteps):
        """Denoise one step.

        Args:
            video_latents: (B, seq_len, hidden_size) video latent tokens.
            text_embeds: (B, text_len, text_hidden_size) text features.
            timesteps: (B,) diffusion timesteps.

        Returns:
            (B, seq_len, hidden_size) output video latent tokens.
        """
        cond = self.time_embed(timestep_embedding(timesteps, self.hidden_size))
        text_tokens = self.text_proj(text_embeds)
        n_text = text_tokens.shape[1]
        # Joint sequence: text tokens first, then video tokens.
        tokens = torch.cat((text_tokens, video_latents), dim=1)
        for idx, block in enumerate(self.layers):
            tokens = block(tokens, cond, n_text, idx)
        # Only the video part goes through the final norm + projection.
        return self.proj_out(self.norm_final(tokens[:, n_text:]))
def timestep_embedding(timesteps, dim, max_period=10000):
    """Sinusoidal diffusion-timestep embedding.

    Args:
        timesteps: (B,) tensor of timesteps (int or float).
        dim: output embedding width.
        max_period: controls the lowest frequency of the sinusoids.

    Returns:
        (B, dim) tensor: [cos(t * f_0..f_{h-1}), sin(t * f_0..f_{h-1})],
        zero-padded by one column when dim is odd.
    """
    half = dim // 2
    exponent = torch.arange(half, dtype=torch.float32) / half
    freqs = torch.exp(-math.log(max_period) * exponent).to(device=timesteps.device)
    angles = timesteps.float()[:, None] * freqs[None, :]
    embedding = torch.cat((angles.cos(), angles.sin()), dim=-1)
    if dim % 2 == 1:
        # Odd dim: pad one zero column so the output width equals `dim`.
        zero_col = torch.zeros_like(embedding[:, :1])
        embedding = torch.cat((embedding, zero_col), dim=-1)
    return embedding
# Smoke test: run the transformer and the 3D VAE on random tensors.
if __name__ == '__main__':
    # Demo dimensions.
    batch_size = 2
    seq_len = 1024        # video token sequence length
    text_len = 77         # text token sequence length
    hidden_size = 1152
    text_hidden_size = 4096

    # Build the diffusion transformer.
    model = CogVideoX(hidden_size=hidden_size, num_layers=28, num_heads=16,
                      text_hidden_size=text_hidden_size, time_embed_dim=512)

    # Random inputs: video latents, text embeddings, diffusion timesteps.
    video_latents = torch.randn(batch_size, seq_len, hidden_size)
    text_embeds = torch.randn(batch_size, text_len, text_hidden_size)
    timesteps = torch.randint(0, 1000, (batch_size,))

    # One forward pass through the transformer.
    output = model(video_latents, text_embeds, timesteps)

    # Exercise the 3D VAE on a short clip (B, C, T, H, W).
    vae = VAE3D(in_channels=3, latent_channels=16)
    video_input = torch.randn(batch_size, 3, 16, 64, 64)
    recon, mean, logvar = vae(video_input)

    # Report tensor shapes and parameter count.
    print('视频潜在表示尺寸:', video_latents.size())
    print('文本嵌入尺寸:', text_embeds.size())
    print('输出尺寸:', output.size())
    print('参数数量:', sum(p.numel() for p in model.parameters()) / 1e6, 'M')
    print('3D VAE输入尺寸:', video_input.size())
    print('3D VAE重建尺寸:', recon.size())