前言
于 11 月底正式开课的扩散模型课程正在火热进行中,在中国社区成员们的帮助下,我们组织了「抱抱脸中文本地化志愿者小组」并完成了扩散模型课程的中文翻译,感谢 @darcula1993、@XhrLeokk、@hoi2022、@SuSung-boy 对课程的翻译!
如果你还没有开始课程的学习,我们建议你从 第一单元:扩散模型简介 开始。
扩散模型从零到一
这个 Notebook 我们将展示相同的步骤(向数据添加噪声、创建模型、训练和采样),并尽可能简单地在 PyTorch 中从头开始实现。然后,我们将这个「玩具示例」与 diffusers 版本进行比较,并关注两者的区别以及改进之处。这里的目标是熟悉不同的组件和其中的设计决策,以便在查看新的实现时能够快速确定关键思想。
让我们开始吧!
有时,只考虑一些事务最简单的情况会有助于更好地理解其工作原理。我们将在本笔记本中尝试这一点,从“玩具”扩散模型开始,看看不同的部分是如何工作的,然后再检查它们与更复杂的实现有何不同。
你将跟随本文的 Notebook 学习到
损坏过程(向数据添加噪声)
什么是 UNet,以及如何从零开始实现一个极小的 UNet
扩散模型训练
抽样理论
然后,我们将比较我们的版本与 diffusers 库中的 DDPM 实现的区别
对小型 UNet 的改进
DDPM 噪声计划
训练目标的差异
timestep 调节
抽样方法
这个笔记本相当深入,如果你对从零开始的深入研究不感兴趣,可以放心地跳过!
还值得注意的是,这里的大多数代码都是出于说明的目的,我不建议直接将其用于您自己的工作(除非您只是为了学习目的而尝试改进这里展示的示例)。
准备环境与导入:
!pip install -q diffusers
import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from diffusers import DDPMScheduler, UNet2DModel
from matplotlib import pyplot as plt
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')
数据
在这里,我们将使用一个非常小的经典数据集 mnist 来进行测试。如果您想在不改变任何其他内容的情况下给模型一个稍微困难一点的挑战,请使用 torchvision.dataset
,FashionMNIST 应作为替代品。
dataset = torchvision.datasets.MNIST(root="mnist/", train=True, download=True, transform=torchvision.transforms.ToTensor())
train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
x, y = next(iter(train_dataloader))
print('Input shape:', x.shape)
print('Labels:', y)
plt.imshow(torchvision.utils.make_grid(x)[0], cmap='Greys');
该数据集中的每张图都是一个数字的 28x28 像素的灰度图,像素值的范围是从 0 到 1。
损坏过程
假设你没有读过任何扩散模型的论文,但你知道这个过程会增加噪声。你会怎么做?
我们可能想要一个简单的方法来控制损坏的程度。那么,如果我们要引入一个参数来控制输入的“噪声量”,那么我们会这么做:
noise = torch.rand_like(x)
noisy_x = (1-amount)*x + amount*noise
如果 amount = 0,则返回输入而不做任何更改。如果 amount = 1,我们将得到一个纯粹的噪声。通过这种方式将输入与噪声混合,我们将输出保持在相同的范围(0 to 1)。
我们可以很容易地实现这一点(但是要注意 tensor 的 shape,以防被广播 (broadcasting) 机制不正确的影响到):
def corrupt(x, amount):
"""Corrupt the input `x` by mixing it with noise according to `amount`"""
noise = torch.rand_like(x)
amount = amount.view(-1, 1, 1, 1) # Sort shape so broadcasting works
return x*(1-amount) + noise*amount
让我们来可视化一下输出的结果,以了解是否符合我们的预期:
# Plotting the input data
fig, axs = plt.subplots(2, 1, figsize=(12, 5))
axs[0].set_title('Input data')
axs[0].imshow(torchvision.utils.make_grid(x)[0], cmap='Greys')
# Adding noise
amount = torch.linspace(0, 1, x.shape[0]) # Left to right -> more corruption
noised_x = corrupt(x, amount)
# Plottinf the noised version
axs[1].set_title('Corrupted data (-- amount increases -->)')
axs[1].imshow(torchvision.utils.make_grid(noised_x)[0], cmap='Greys');
当噪声量接近 1 时,我们的数据开始看起来像纯随机噪声。但对于大多数的噪声情况下,您还是可以很好地识别出数字。你认为这是最佳的吗?
模型
我们想要一个模型,它可以接收 28px 的噪声图像,并输出相同形状的预测。一个比较流行的选择是一个叫做 UNet 的架构。最初被发明用于医学图像中的分割任务,UNet 由一个“压缩路径”和一个“扩展路径”组成。“压缩路径”会使通过该路径的数据被压缩,而通过“扩展路径”会将数据扩展回原始维度(类似于自动编码器)。模型中的残差连接也允许信息和梯度在不同层级之间流动。
一些 UNet 的设计在每个阶段都有复杂的 blocks,但对于这个玩具 demo,我们只会构建一个最简单的示例,它接收一个单通道图像,并通过下行路径上的三个卷积层(图和代码中的 down_layers)和上行路径上的 3 个卷积层,在下行和上行层之间具有残差连接。我们将使用 max pooling 进行下采样和 nn.Upsample
用于上采样。某些比较复杂的 UNets 的设计会使用带有可学习参数的上采样和下采样 layer。下面的结构图大致展示了每个 layer 的输出通道数:

代码实现如下:
class BasicUNet(nn.Module):
"""A minimal UNet implementation."""
def __init__(self, in_channels=1, out_channels=1):
super().__init__()
self.down_layers = torch.nn.ModuleList([
nn.Conv2d(in_channels, 32, kernel_size=5, padding=2),
nn.Conv2d(32, 64, kernel_size=5, padding=2),
nn.Conv2d(64, 64, kernel_size=5, padding=2),
])
self.up_layers = torch.nn.ModuleList([
nn.Conv2d(64, 64, kernel_size=5, padding=2),
nn.Conv2d(64, 32, kernel_size=5, padding=2),
nn.Conv2d(32, out_channels, kernel_size=5, padding=2),
])
self.act = nn.SiLU() # The activation function
self.downscale = nn.MaxPool2d(2)
self.upscale = nn.Upsample(scale_factor=2)
def forward(self, x):
h = []
for i, l in enumerate(self.down_layers):
x = self.act(l(x)) # Through the layer n the activation function
if i < 2: # For all but the third (final) down layer:
h.append(x) # Storing output for skip connection
x = self.downscale(x) # Downscale ready for the next layer
for i, l in enumerate(self.up_layers):
if i > 0: # For all except the first up layer
x = self.upscale(x) # Upscale
x += h.pop() # Fetching stored output (skip connection)
x = self.act(l(x)) # Through the layer n the activation function
return x
我们可以验证输出 shape 是否如我们期望的那样与输入相同:
net = BasicUNet()
x = torch.rand(8, 1, 28, 28)
net(x).shape
torch.Size([8, 1, 28, 28])
该网络有 30 多万个参数:
sum([p.numel() for p in net.parameters()])
309057
您可以尝试更改每个 layer 中的通道数或尝试不同的结构设计。
训练模型
那么,模型到底应该做什么呢?同样,对这个问题有各种不同的看法,但对于这个演示,让我们选择一个简单的框架:给定一个损坏的输入 noisy_x
,模型应该输出它对原本 x
的最佳猜测。我们将通过均方误差将预测与真实值进行比较。
我们现在可以尝试训练网络了。
获取一批数据
添加随机噪声
将数据输入模型
将模型预测与干净图像进行比较,以计算 loss
更新模型的参数
你可以自由进行修改来尝试获得更好的结果!
# Dataloader (you can mess with batch size)
batch_size = 128
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# How many runs through the data should we do?
n_epochs = 3
# Create the network
net = BasicUNet()
net.to(device)
# Our loss finction
loss_fn = nn.MSELoss()
# The optimizer
opt = torch.optim.Adam(net.parameters(), lr=1e-3)
# Keeping a record of the losses for later viewing
losses = []
# The training