MAE课程设计案例:斯坦福CS231n作业中的掩码自编码器实现
引言:计算机视觉中的自监督学习痛点与解决方案
你是否还在为计算机视觉任务中标注数据匮乏而烦恼?是否在寻找一种能够利用海量无标签图像进行高效预训练的方法?本文将以斯坦福大学CS231n课程设计为背景,详细介绍如何基于掩码自编码器(Masked Autoencoder, MAE)实现一个高性能的视觉预训练模型。通过本文,你将掌握:
- MAE的核心原理与实现细节
- 如何从零开始构建MAE模型
- 在PyTorch中高效实现MAE的关键技巧
- MAE在图像重建任务中的应用与可视化
- 课程设计中的常见问题与解决方案
MAE核心原理与架构设计
MAE的工作原理概述
掩码自编码器(MAE)是一种基于Transformer的自监督学习框架,其核心思想是通过随机掩码输入图像的大部分区域,然后训练模型从少量可见的图像块中重建原始图像。这种方法不仅能够学习到丰富的视觉表征,还大大提高了训练效率。
MAE的工作流程可以分为以下几个关键步骤:
- 将输入图像分割为固定大小的图像块(Patch)
- 随机掩码大部分图像块(通常掩码比例为75%)
- 使用Transformer编码器处理可见的图像块
- 使用Transformer解码器从编码的可见块中重建原始图像
- 通过计算重建图像与原始图像的损失来更新模型参数
MAE架构详解
MAE的整体架构由三个主要部分组成:编码器、解码器和损失函数。
编码器结构
编码器的主要作用是对可见的图像块进行特征提取,其结构如下:
关键组件说明:
- PatchEmbed:将图像分割为块并进行线性嵌入
- 位置嵌入(Positional Embedding):为每个图像块添加位置信息
- Transformer块:由多头自注意力和MLP组成的核心特征提取单元
- 随机掩码:实现对输入图像块的随机掩码
解码器结构
解码器的主要作用是从编码的可见块中重建原始图像,其结构如下:
关键组件说明:
- decoder_embed:将编码器输出映射到解码器维度
- mask_token:用于填充掩码位置的可学习参数
- decoder_pos_embed:解码器的位置嵌入
- decoder_blocks:解码器的Transformer块
- decoder_pred:将解码器输出映射回图像块空间
损失函数
MAE使用像素级的均方误差(MSE)作为损失函数,但只计算掩码区域的损失,公式如下:
$$\mathcal{L} = \frac{1}{N \cdot L_{mask}} \sum_{i=1}^{N} \sum_{j \in M_i} | \hat{x}{i,j} - x{i,j} |^2$$
其中,$N$是批量大小,$L_{mask}$是每个样本的掩码块数量,$M_i$是第$i$个样本的掩码位置集合,$\hat{x}{i,j}$是重建的图像块,$x{i,j}$是原始图像块。
MAE的PyTorch实现详解
环境准备与依赖安装
在开始实现MAE之前,需要确保安装以下依赖:
# 克隆项目仓库
git clone https://gitcode.com/gh_mirrors/ma/mae
cd mae
# 创建并激活虚拟环境
conda create -n mae_cs231n python=3.8
conda activate mae_cs231n
# 安装依赖包
pip install torch torchvision timm matplotlib numpy
核心代码实现
1. 图像分块与嵌入
首先实现图像分块与嵌入模块:
class PatchEmbed(nn.Module):
"""将图像分割为补丁并进行线性嵌入"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
img_size = (img_size, img_size)
patch_size = (patch_size, patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x):
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
return x
2. MAE主模型实现
下面实现MAE的核心模型类:
class MaskedAutoencoderViT(nn.Module):
"""带Vision Transformer骨干的掩码自编码器"""
def __init__(self, img_size=224, patch_size=16, in_chans=3,
embed_dim=1024, depth=24, num_heads=16,
decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False):
super().__init__()
# MAE编码器特定部分
self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False) # 固定的sin-cos嵌入
self.blocks = nn.ModuleList([
Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
for i in range(depth)])
self.norm = norm_layer(embed_dim)
# MAE解码器特定部分
self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False) # 固定的sin-cos嵌入
self.decoder_blocks = nn.ModuleList([
Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
for i in range(decoder_depth)])
self.decoder_norm = norm_layer(decoder_embed_dim)
self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True) # 解码器到补丁
self.norm_pix_loss = norm_pix_loss
self.initialize_weights()
3. 随机掩码实现
随机掩码是MAE的核心组件之一,实现如下:
def random_masking(self, x, mask_ratio):
"""
执行每个样本的随机掩码,通过每个样本的洗牌实现。
通过对随机噪声进行argsort来完成每个样本的洗牌。
x: [N, L, D],序列
"""
N, L, D = x.shape # batch, length, dim
len_keep = int(L * (1 - mask_ratio))
noise = torch.rand(N, L, device=x.device) # 噪声在[0, 1]范围内
# 对每个样本的噪声进行排序
ids_shuffle = torch.argsort(noise, dim=1) # 升序:小的保留,大的移除
ids_restore = torch.argsort(ids_shuffle, dim=1)
# 保留第一个子集
ids_keep = ids_shuffle[:, :len_keep]
x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
# 生成二进制掩码:0表示保留,1表示移除
mask = torch.ones([N, L], device=x.device)
mask[:, :len_keep] = 0
# 反洗牌以获取二进制掩码
mask = torch.gather(mask, dim=1, index=ids_restore)
return x_masked, mask, ids_restore
4. 编码器前向传播
编码器的前向传播实现:
def forward_encoder(self, x, mask_ratio):
# 嵌入补丁
x = self.patch_embed(x)
# 添加位置嵌入(不含cls token)
x = x + self.pos_embed[:, 1:, :]
# 掩码:长度 -> 长度 * mask_ratio
x, mask, ids_restore = self.random_masking(x, mask_ratio)
# 附加cls token
cls_token = self.cls_token + self.pos_embed[:, :1, :]
cls_tokens = cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
# 应用Transformer块
for blk in self.blocks:
x = blk(x)
x = self.norm(x)
return x, mask, ids_restore
5. 解码器前向传播
解码器的前向传播实现:
def forward_decoder(self, x, ids_restore):
# 嵌入令牌
x = self.decoder_embed(x)
# 向序列追加掩码令牌
mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # 不含cls token
x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # 反洗牌
x = torch.cat([x[:, :1, :], x_], dim=1) # 附加cls token
# 添加位置嵌入
x = x + self.decoder_pos_embed
# 应用Transformer块
for blk in self.decoder_blocks:
x = blk(x)
x = self.decoder_norm(x)
# 预测器投影
x = self.decoder_pred(x)
# 移除cls token
x = x[:, 1:, :]
return x
6. 损失函数实现
MAE的损失函数实现:
def forward_loss(self, imgs, pred, mask):
"""
imgs: [N, 3, H, W]
pred: [N, L, p*p*3]
mask: [N, L], 0表示保留,1表示移除,
"""
target = self.patchify(imgs)
if self.norm_pix_loss:
mean = target.mean(dim=-1, keepdim=True)
var = target.var(dim=-1, keepdim=True)
target = (target - mean) / (var + 1.e-6)**.5
loss = (pred - target) ** 2
loss = loss.mean(dim=-1) # [N, L], 每个补丁的平均损失
loss = (loss * mask).sum() / mask.sum() # 移除补丁的平均损失
return loss
模型初始化与权重设置
MAE的参数初始化对于模型性能至关重要:
def initialize_weights(self):
# 初始化
# 通过sin-cos嵌入初始化(并冻结)pos_embed
pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
# 像nn.Linear一样初始化patch_embed(而不是nn.Conv2d)
w = self.patch_embed.proj.weight.data
torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
# timm的trunc_normal_(std=.02)实际上等同于normal_(std=0.02),因为截止值太大(2.)
torch.nn.init.normal_(self.cls_token, std=.02)
torch.nn.init.normal_(self.mask_token, std=.02)
# 初始化nn.Linear和nn.LayerNorm
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
# 我们使用xavier_uniform,遵循官方JAX ViT:
torch.nn.init.xavier_uniform_(m.weight)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
MAE模型的使用与训练
模型实例化
MAE提供了几种预定义的模型配置,可以直接实例化使用:
def mae_vit_base_patch16_dec512d8b(**kwargs):
model = MaskedAutoencoderViT(
patch_size=16, embed_dim=768, depth=12, num_heads=12,
decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
return model
def mae_vit_large_patch16_dec512d8b(**kwargs):
model = MaskedAutoencoderViT(
patch_size=16, embed_dim=1024, depth=24, num_heads=16,
decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
return model
def mae_vit_huge_patch14_dec512d8b(**kwargs):
model = MaskedAutoencoderViT(
patch_size=14, embed_dim=1280, depth=32, num_heads=16,
decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
return model
# 设置推荐的架构
mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b # 解码器:512维,8个块
mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b # 解码器:512维,8个块
mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b # 解码器:512维,8个块
使用示例:
# 创建基础版MAE模型
model = mae_vit_base_patch16(img_size=224, in_chans=3, norm_pix_loss=False)
# 打印模型参数量
print(f"MAE模型参数量: {sum(p.numel() for p in model.parameters()):,}")
数据准备
MAE可以使用无标签数据进行训练,这里以ImageNet数据集为例:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# 定义数据变换
transform = transforms.Compose([
transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# 加载数据集(使用无标签数据)
dataset = datasets.ImageFolder(root='path/to/imagenet/train', transform=transform)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=8)
训练过程实现
MAE的训练过程实现:
import torch.optim as optim
from tqdm import tqdm
# 设置设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
# 定义优化器和学习率调度器
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.05)
lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
# 训练循环
num_epochs = 100
mask_ratio = 0.75 # 掩码比例
model.train()
for epoch in range(num_epochs):
total_loss = 0
progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
for images, _ in progress_bar:
images = images.to(device)
# 前向传播
loss, pred, mask = model(images, mask_ratio)
# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
progress_bar.set_postfix(loss=loss.item())
# 计算平均损失
avg_loss = total_loss / len(dataloader)
print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")
# 更新学习率
lr_scheduler.step()
# 保存模型
if (epoch + 1) % 10 == 0:
torch.save(model.state_dict(), f"mae_model_epoch_{epoch+1}.pth")
MAE的图像重建可视化
为了直观理解MAE的工作效果,我们可以可视化其图像重建结果。
图像重建函数
实现图像补丁化和反补丁化函数:
def patchify(self, imgs):
"""
imgs: (N, 3, H, W)
x: (N, L, patch_size**2 *3)
"""
p = self.patch_embed.patch_size[0]
assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
h = w = imgs.shape[2] // p
x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
x = torch.einsum('nchpwq->nhwpqc', x)
x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
return x
def unpatchify(self, x):
"""
x: (N, L, patch_size**2 *3)
imgs: (N, 3, H, W)
"""
p = self.patch_embed.patch_size[0]
h = w = int(x.shape[1]**.5)
assert h * w == x.shape[1]
x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
x = torch.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
return imgs
重建可视化代码
使用matplotlib可视化原始图像、掩码和重建结果:
import matplotlib.pyplot as plt
import numpy as np
def visualize_reconstruction(model, img, mask_ratio=0.75):
"""可视化MAE的图像重建结果"""
model.eval()
with torch.no_grad():
img = img.unsqueeze(0).to(device)
loss, pred, mask = model(img, mask_ratio)
pred = model.unpatchify(pred)
# 将张量转换为图像格式
img = img.cpu().squeeze().permute(1, 2, 0)
pred = pred.cpu().squeeze().permute(1, 2, 0)
# 反归一化
mean = torch.tensor([0.485, 0.456, 0.406])
std = torch.tensor([0.229, 0.224, 0.225])
img = img * std + mean
pred = pred * std + mean
# 确保像素值在[0, 1]范围内
img = np.clip(img.numpy(), 0, 1)
pred = np.clip(pred.numpy(), 0, 1)
# 创建掩码图像
mask = mask.cpu().squeeze()
mask = mask.reshape(-1, 1).repeat(1, model.patch_embed.patch_size[0]**2 * 3)
mask = model.unpatchify(mask.unsqueeze(0)).squeeze().permute(1, 2, 0)
mask = np.clip(mask.numpy(), 0, 1)
masked_img = img * (1 - mask) + np.ones_like(img) * mask
# 绘制结果
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
axes[0].imshow(img)
axes[0].set_title('原始图像')
axes[0].axis('off')
axes[1].imshow(masked_img)
axes[1].set_title(f'掩码图像 (掩码比例: {mask_ratio*100}%)')
axes[1].axis('off')
axes[2].imshow(pred)
axes[2].set_title('重建图像')
axes[2].axis('off')
plt.tight_layout()
plt.show()
# 使用示例
# 加载一张测试图像
test_img = datasets.ImageFolder(root='path/to/test/images', transform=transform)[0][0]
# 可视化重建结果
visualize_reconstruction(model, test_img, mask_ratio=0.75)
课程设计中的常见问题与解决方案
计算资源优化
MAE模型较大,训练时可能面临显存不足的问题,可采用以下优化方法:
# 1. 使用混合精度训练
scaler = torch.cuda.amp.GradScaler()
# 在训练循环中
with torch.cuda.amp.autocast():
loss, pred, mask = model(images, mask_ratio)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
# 2. 使用梯度累积
accumulation_steps = 4
# 在训练循环中
loss = loss / accumulation_steps
loss.backward()
if (i + 1) % accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad()
超参数调优
MAE的性能受多个超参数影响,以下是推荐的超参数范围:
超参数 | 推荐范围 | 说明 |
---|---|---|
掩码比例 | 0.6-0.8 | 较高的掩码比例通常能获得更好的特征,但训练难度增加 |
学习率 | 1e-4-5e-4 | 使用AdamW优化器时的初始学习率 |
权重衰减 | 0.05-0.1 | 控制模型过拟合 |
批大小 | 尽可能大 | 在显存允许范围内,批大小越大越好 |
编码器深度 | 12-32 | 更深的编码器能学习更复杂的特征 |
解码器深度 | 4-12 | 解码器通常比编码器浅以提高效率 |
模型评估方法
评估MAE的性能可以从以下几个方面进行:
-
重建质量评估:
from skimage.metrics import peak_signal_noise_ratio, structural_similarity def evaluate_reconstruction(model, dataloader): psnr_scores = [] ssim_scores = [] model.eval() with torch.no_grad(): for images, _ in dataloader: images = images.to(device) loss, pred, mask = model(images, mask_ratio=0.75) pred = model.unpatchify(pred) # 计算PSNR和SSIM for img, p in zip(images, pred): img = img.cpu().permute(1, 2, 0).numpy() p = p.cpu().permute(1, 2, 0).numpy() psnr = peak_signal_noise_ratio(img, p) ssim = structural_similarity(img, p, multichannel=True) psnr_scores.append(psnr) ssim_scores.append(ssim) return np.mean(psnr_scores), np.mean(ssim_scores) # 使用示例 psnr, ssim = evaluate_reconstruction(model, test_dataloader) print(f"平均PSNR: {psnr:.2f}, 平均SSIM: {ssim:.4f}")
-
下游任务评估: 将MAE预训练的权重迁移到分类任务,评估特征质量:
# 提取编码器作为特征提取器 encoder = model.patch_embed + model.blocks + model.norm # 在分类任务上微调 classifier = nn.Sequential( nn.Linear(embed_dim, 512), nn.ReLU(), nn.Linear(512, num_classes) ) # 仅微调分类器 for param in encoder.parameters(): param.requires_grad = False # 训练分类器...
总结与展望
本文详细介绍了在斯坦福CS231n课程设计中实现掩码自编码器(MAE)的完整过程,包括MAE的核心原理、架构设计、PyTorch实现细节、训练过程和可视化方法。通过本课程设计,你不仅能够深入理解MAE这一前沿的自监督学习方法,还能掌握在实际项目中实现和应用复杂深度学习模型的关键技能。
MAE作为一种高效的自监督学习框架,在计算机视觉领域有着广泛的应用前景。未来可以进一步探索以下方向:
- 探索MAE在视频、3D点云等其他视觉模态上的应用
- 研究更有效的掩码策略,如语义感知掩码
- 将MAE与其他自监督学习方法结合,进一步提升性能
- 探索MAE在低资源场景下的应用
希望本文能够帮助你顺利完成课程设计,并为你在计算机视觉领域的深入学习奠定基础!
参考资料
- He, K., Chen, X., Xie, S., Li, Y., Dollár, P., & Girshick, R. (2021). Masked autoencoders are scalable vision learners. arXiv preprint arXiv:2111.06377.
- Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., ... & Houlsby, N. (2020). An image is worth 16x16 words: Transformers for image recognition at scale. arXiv preprint arXiv:2010.11929.
- Stanford CS231n: Convolutional Neural Networks for Visual Recognition. https://cs231n.stanford.edu/
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考