MAE课程设计案例:斯坦福CS231n作业中的掩码自编码器实现

MAE课程设计案例:斯坦福CS231n作业中的掩码自编码器实现

【免费下载链接】mae PyTorch implementation of MAE https//arxiv.org/abs/2111.06377 【免费下载链接】mae 项目地址: https://gitcode.com/gh_mirrors/ma/mae

引言:计算机视觉中的自监督学习痛点与解决方案

你是否还在为计算机视觉任务中标注数据匮乏而烦恼?是否在寻找一种能够利用海量无标签图像进行高效预训练的方法?本文将以斯坦福大学CS231n课程设计为背景,详细介绍如何基于掩码自编码器(Masked Autoencoder, MAE)实现一个高性能的视觉预训练模型。通过本文,你将掌握:

  • MAE的核心原理与实现细节
  • 如何从零开始构建MAE模型
  • 在PyTorch中高效实现MAE的关键技巧
  • MAE在图像重建任务中的应用与可视化
  • 课程设计中的常见问题与解决方案

MAE核心原理与架构设计

MAE的工作原理概述

掩码自编码器(MAE)是一种基于Transformer的自监督学习框架,其核心思想是通过随机掩码输入图像的大部分区域,然后训练模型从少量可见的图像块中重建原始图像。这种方法不仅能够学习到丰富的视觉表征,还大大提高了训练效率。

mermaid

MAE的工作流程可以分为以下几个关键步骤:

  1. 将输入图像分割为固定大小的图像块(Patch)
  2. 随机掩码大部分图像块(通常掩码比例为75%)
  3. 使用Transformer编码器处理可见的图像块
  4. 使用Transformer解码器从编码的可见块中重建原始图像
  5. 通过计算重建图像与原始图像的损失来更新模型参数

MAE架构详解

MAE的整体架构由三个主要部分组成:编码器解码器损失函数

编码器结构

编码器的主要作用是对可见的图像块进行特征提取,其结构如下:

mermaid

关键组件说明:

  • PatchEmbed:将图像分割为块并进行线性嵌入
  • 位置嵌入(Positional Embedding):为每个图像块添加位置信息
  • Transformer块:由多头自注意力和MLP组成的核心特征提取单元
  • 随机掩码:实现对输入图像块的随机掩码
解码器结构

解码器的主要作用是从编码的可见块中重建原始图像,其结构如下:

mermaid

关键组件说明:

  • decoder_embed:将编码器输出映射到解码器维度
  • mask_token:用于填充掩码位置的可学习参数
  • decoder_pos_embed:解码器的位置嵌入
  • decoder_blocks:解码器的Transformer块
  • decoder_pred:将解码器输出映射回图像块空间
损失函数

MAE使用像素级的均方误差(MSE)作为损失函数,但只计算掩码区域的损失,公式如下:

$$\mathcal{L} = \frac{1}{N \cdot L_{mask}} \sum_{i=1}^{N} \sum_{j \in M_i} | \hat{x}{i,j} - x{i,j} |^2$$

其中,$N$是批量大小,$L_{mask}$是每个样本的掩码块数量,$M_i$是第$i$个样本的掩码位置集合,$\hat{x}{i,j}$是重建的图像块,$x{i,j}$是原始图像块。

MAE的PyTorch实现详解

环境准备与依赖安装

在开始实现MAE之前,需要确保安装以下依赖:

# 克隆项目仓库
git clone https://gitcode.com/gh_mirrors/ma/mae
cd mae

# 创建并激活虚拟环境
conda create -n mae_cs231n python=3.8
conda activate mae_cs231n

# 安装依赖包
pip install torch torchvision timm matplotlib numpy

核心代码实现

1. 图像分块与嵌入

首先实现图像分块与嵌入模块:

class PatchEmbed(nn.Module):
    """将图像分割为补丁并进行线性嵌入"""
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = (img_size, img_size)
        patch_size = (patch_size, patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        return x
2. MAE主模型实现

下面实现MAE的核心模型类:

class MaskedAutoencoderViT(nn.Module):
    """带Vision Transformer骨干的掩码自编码器"""
    def __init__(self, img_size=224, patch_size=16, in_chans=3,
                 embed_dim=1024, depth=24, num_heads=16,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False):
        super().__init__()

        # MAE编码器特定部分
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)  # 固定的sin-cos嵌入

        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        # MAE解码器特定部分
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False)  # 固定的sin-cos嵌入

        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
            for i in range(decoder_depth)])

        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True)  # 解码器到补丁

        self.norm_pix_loss = norm_pix_loss

        self.initialize_weights()
3. 随机掩码实现

随机掩码是MAE的核心组件之一,实现如下:

def random_masking(self, x, mask_ratio):
    """
    执行每个样本的随机掩码,通过每个样本的洗牌实现。
    通过对随机噪声进行argsort来完成每个样本的洗牌。
    x: [N, L, D],序列
    """
    N, L, D = x.shape  # batch, length, dim
    len_keep = int(L * (1 - mask_ratio))
    
    noise = torch.rand(N, L, device=x.device)  # 噪声在[0, 1]范围内
    
    # 对每个样本的噪声进行排序
    ids_shuffle = torch.argsort(noise, dim=1)  # 升序:小的保留,大的移除
    ids_restore = torch.argsort(ids_shuffle, dim=1)

    # 保留第一个子集
    ids_keep = ids_shuffle[:, :len_keep]
    x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

    # 生成二进制掩码:0表示保留,1表示移除
    mask = torch.ones([N, L], device=x.device)
    mask[:, :len_keep] = 0
    # 反洗牌以获取二进制掩码
    mask = torch.gather(mask, dim=1, index=ids_restore)

    return x_masked, mask, ids_restore
4. 编码器前向传播

编码器的前向传播实现:

def forward_encoder(self, x, mask_ratio):
    # 嵌入补丁
    x = self.patch_embed(x)

    # 添加位置嵌入(不含cls token)
    x = x + self.pos_embed[:, 1:, :]

    # 掩码:长度 -> 长度 * mask_ratio
    x, mask, ids_restore = self.random_masking(x, mask_ratio)

    # 附加cls token
    cls_token = self.cls_token + self.pos_embed[:, :1, :]
    cls_tokens = cls_token.expand(x.shape[0], -1, -1)
    x = torch.cat((cls_tokens, x), dim=1)

    # 应用Transformer块
    for blk in self.blocks:
        x = blk(x)
    x = self.norm(x)

    return x, mask, ids_restore
5. 解码器前向传播

解码器的前向传播实现:

def forward_decoder(self, x, ids_restore):
    # 嵌入令牌
    x = self.decoder_embed(x)

    # 向序列追加掩码令牌
    mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
    x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # 不含cls token
    x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # 反洗牌
    x = torch.cat([x[:, :1, :], x_], dim=1)  # 附加cls token

    # 添加位置嵌入
    x = x + self.decoder_pos_embed

    # 应用Transformer块
    for blk in self.decoder_blocks:
        x = blk(x)
    x = self.decoder_norm(x)

    # 预测器投影
    x = self.decoder_pred(x)

    # 移除cls token
    x = x[:, 1:, :]

    return x
6. 损失函数实现

MAE的损失函数实现:

def forward_loss(self, imgs, pred, mask):
    """
    imgs: [N, 3, H, W]
    pred: [N, L, p*p*3]
    mask: [N, L], 0表示保留,1表示移除,
    """
    target = self.patchify(imgs)
    if self.norm_pix_loss:
        mean = target.mean(dim=-1, keepdim=True)
        var = target.var(dim=-1, keepdim=True)
        target = (target - mean) / (var + 1.e-6)**.5

    loss = (pred - target) ** 2
    loss = loss.mean(dim=-1)  # [N, L], 每个补丁的平均损失

    loss = (loss * mask).sum() / mask.sum()  # 移除补丁的平均损失
    return loss

模型初始化与权重设置

MAE的参数初始化对于模型性能至关重要:

def initialize_weights(self):
    # 初始化
    # 通过sin-cos嵌入初始化(并冻结)pos_embed
    pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
    self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

    decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
    self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

    # 像nn.Linear一样初始化patch_embed(而不是nn.Conv2d)
    w = self.patch_embed.proj.weight.data
    torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

    # timm的trunc_normal_(std=.02)实际上等同于normal_(std=0.02),因为截止值太大(2.)
    torch.nn.init.normal_(self.cls_token, std=.02)
    torch.nn.init.normal_(self.mask_token, std=.02)

    # 初始化nn.Linear和nn.LayerNorm
    self.apply(self._init_weights)

def _init_weights(self, m):
    if isinstance(m, nn.Linear):
        # 我们使用xavier_uniform,遵循官方JAX ViT:
        torch.nn.init.xavier_uniform_(m.weight)
        if isinstance(m, nn.Linear) and m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        nn.init.constant_(m.bias, 0)
        nn.init.constant_(m.weight, 1.0)

MAE模型的使用与训练

模型实例化

MAE提供了几种预定义的模型配置,可以直接实例化使用:

def mae_vit_base_patch16_dec512d8b(**kwargs):
    model = MaskedAutoencoderViT(
        patch_size=16, embed_dim=768, depth=12, num_heads=12,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model

def mae_vit_large_patch16_dec512d8b(**kwargs):
    model = MaskedAutoencoderViT(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model

def mae_vit_huge_patch14_dec512d8b(**kwargs):
    model = MaskedAutoencoderViT(
        patch_size=14, embed_dim=1280, depth=32, num_heads=16,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model

# 设置推荐的架构
mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b  # 解码器:512维,8个块
mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b  # 解码器:512维,8个块
mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b  # 解码器:512维,8个块

使用示例:

# 创建基础版MAE模型
model = mae_vit_base_patch16(img_size=224, in_chans=3, norm_pix_loss=False)

# 打印模型参数量
print(f"MAE模型参数量: {sum(p.numel() for p in model.parameters()):,}")

数据准备

MAE可以使用无标签数据进行训练,这里以ImageNet数据集为例:

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# 定义数据变换
transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 加载数据集(使用无标签数据)
dataset = datasets.ImageFolder(root='path/to/imagenet/train', transform=transform)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=8)

训练过程实现

MAE的训练过程实现:

import torch.optim as optim
from tqdm import tqdm

# 设置设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# 定义优化器和学习率调度器
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.05)
lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

# 训练循环
num_epochs = 100
mask_ratio = 0.75  # 掩码比例

model.train()
for epoch in range(num_epochs):
    total_loss = 0
    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
    
    for images, _ in progress_bar:
        images = images.to(device)
        
        # 前向传播
        loss, pred, mask = model(images, mask_ratio)
        
        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        progress_bar.set_postfix(loss=loss.item())
    
    # 计算平均损失
    avg_loss = total_loss / len(dataloader)
    print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")
    
    # 更新学习率
    lr_scheduler.step()
    
    # 保存模型
    if (epoch + 1) % 10 == 0:
        torch.save(model.state_dict(), f"mae_model_epoch_{epoch+1}.pth")

MAE的图像重建可视化

为了直观理解MAE的工作效果,我们可以可视化其图像重建结果。

图像重建函数

实现图像补丁化和反补丁化函数:

def patchify(self, imgs):
    """
    imgs: (N, 3, H, W)
    x: (N, L, patch_size**2 *3)
    """
    p = self.patch_embed.patch_size[0]
    assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

    h = w = imgs.shape[2] // p
    x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
    x = torch.einsum('nchpwq->nhwpqc', x)
    x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
    return x

def unpatchify(self, x):
    """
    x: (N, L, patch_size**2 *3)
    imgs: (N, 3, H, W)
    """
    p = self.patch_embed.patch_size[0]
    h = w = int(x.shape[1]**.5)
    assert h * w == x.shape[1]
    
    x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
    x = torch.einsum('nhwpqc->nchpwq', x)
    imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
    return imgs

重建可视化代码

使用matplotlib可视化原始图像、掩码和重建结果:

import matplotlib.pyplot as plt
import numpy as np

def visualize_reconstruction(model, img, mask_ratio=0.75):
    """可视化MAE的图像重建结果"""
    model.eval()
    with torch.no_grad():
        img = img.unsqueeze(0).to(device)
        loss, pred, mask = model(img, mask_ratio)
        pred = model.unpatchify(pred)
        
        # 将张量转换为图像格式
        img = img.cpu().squeeze().permute(1, 2, 0)
        pred = pred.cpu().squeeze().permute(1, 2, 0)
        
        # 反归一化
        mean = torch.tensor([0.485, 0.456, 0.406])
        std = torch.tensor([0.229, 0.224, 0.225])
        img = img * std + mean
        pred = pred * std + mean
        
        # 确保像素值在[0, 1]范围内
        img = np.clip(img.numpy(), 0, 1)
        pred = np.clip(pred.numpy(), 0, 1)
        
        # 创建掩码图像
        mask = mask.cpu().squeeze()
        mask = mask.reshape(-1, 1).repeat(1, model.patch_embed.patch_size[0]**2 * 3)
        mask = model.unpatchify(mask.unsqueeze(0)).squeeze().permute(1, 2, 0)
        mask = np.clip(mask.numpy(), 0, 1)
        masked_img = img * (1 - mask) + np.ones_like(img) * mask
        
        # 绘制结果
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        axes[0].imshow(img)
        axes[0].set_title('原始图像')
        axes[0].axis('off')
        
        axes[1].imshow(masked_img)
        axes[1].set_title(f'掩码图像 (掩码比例: {mask_ratio*100}%)')
        axes[1].axis('off')
        
        axes[2].imshow(pred)
        axes[2].set_title('重建图像')
        axes[2].axis('off')
        
        plt.tight_layout()
        plt.show()

# 使用示例
# 加载一张测试图像
test_img = datasets.ImageFolder(root='path/to/test/images', transform=transform)[0][0]

# 可视化重建结果
visualize_reconstruction(model, test_img, mask_ratio=0.75)

课程设计中的常见问题与解决方案

计算资源优化

MAE模型较大,训练时可能面临显存不足的问题,可采用以下优化方法:

# 1. 使用混合精度训练
scaler = torch.cuda.amp.GradScaler()

# 在训练循环中
with torch.cuda.amp.autocast():
    loss, pred, mask = model(images, mask_ratio)

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

# 2. 使用梯度累积
accumulation_steps = 4

# 在训练循环中
loss = loss / accumulation_steps
loss.backward()

if (i + 1) % accumulation_steps == 0:
    optimizer.step()
    optimizer.zero_grad()

超参数调优

MAE的性能受多个超参数影响,以下是推荐的超参数范围:

超参数推荐范围说明
掩码比例0.6-0.8较高的掩码比例通常能获得更好的特征,但训练难度增加
学习率1e-4-5e-4使用AdamW优化器时的初始学习率
权重衰减0.05-0.1控制模型过拟合
批大小尽可能大在显存允许范围内,批大小越大越好
编码器深度12-32更深的编码器能学习更复杂的特征
解码器深度4-12解码器通常比编码器浅以提高效率

模型评估方法

评估MAE的性能可以从以下几个方面进行:

  1. 重建质量评估

    from skimage.metrics import peak_signal_noise_ratio, structural_similarity
    
    def evaluate_reconstruction(model, dataloader):
        psnr_scores = []
        ssim_scores = []
    
        model.eval()
        with torch.no_grad():
            for images, _ in dataloader:
                images = images.to(device)
                loss, pred, mask = model(images, mask_ratio=0.75)
                pred = model.unpatchify(pred)
    
                # 计算PSNR和SSIM
                for img, p in zip(images, pred):
                    img = img.cpu().permute(1, 2, 0).numpy()
                    p = p.cpu().permute(1, 2, 0).numpy()
    
                    psnr = peak_signal_noise_ratio(img, p)
                    ssim = structural_similarity(img, p, multichannel=True)
    
                    psnr_scores.append(psnr)
                    ssim_scores.append(ssim)
    
        return np.mean(psnr_scores), np.mean(ssim_scores)
    
    # 使用示例
    psnr, ssim = evaluate_reconstruction(model, test_dataloader)
    print(f"平均PSNR: {psnr:.2f}, 平均SSIM: {ssim:.4f}")
    
  2. 下游任务评估: 将MAE预训练的权重迁移到分类任务,评估特征质量:

    # 提取编码器作为特征提取器
    encoder = model.patch_embed + model.blocks + model.norm
    
    # 在分类任务上微调
    classifier = nn.Sequential(
        nn.Linear(embed_dim, 512),
        nn.ReLU(),
        nn.Linear(512, num_classes)
    )
    
    # 仅微调分类器
    for param in encoder.parameters():
        param.requires_grad = False
    
    # 训练分类器...
    

总结与展望

本文详细介绍了在斯坦福CS231n课程设计中实现掩码自编码器(MAE)的完整过程,包括MAE的核心原理、架构设计、PyTorch实现细节、训练过程和可视化方法。通过本课程设计,你不仅能够深入理解MAE这一前沿的自监督学习方法,还能掌握在实际项目中实现和应用复杂深度学习模型的关键技能。

MAE作为一种高效的自监督学习框架,在计算机视觉领域有着广泛的应用前景。未来可以进一步探索以下方向:

  • 探索MAE在视频、3D点云等其他视觉模态上的应用
  • 研究更有效的掩码策略,如语义感知掩码
  • 将MAE与其他自监督学习方法结合,进一步提升性能
  • 探索MAE在低资源场景下的应用

希望本文能够帮助你顺利完成课程设计,并为你在计算机视觉领域的深入学习奠定基础!

参考资料

  1. He, K., Chen, X., Xie, S., Li, Y., Dollár, P., & Girshick, R. (2021). Masked autoencoders are scalable vision learners. arXiv preprint arXiv:2111.06377.
  2. Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., ... & Houlsby, N. (2020). An image is worth 16x16 words: Transformers for image recognition at scale. arXiv preprint arXiv:2010.11929.
  3. Stanford CS231n: Convolutional Neural Networks for Visual Recognition. https://cs231n.stanford.edu/

【免费下载链接】mae PyTorch implementation of MAE https//arxiv.org/abs/2111.06377 【免费下载链接】mae 项目地址: https://gitcode.com/gh_mirrors/ma/mae

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值