第四章 Diffusers 实战
安装 Diffusers 库
pip install -qq -U diffusers datasets transformers accelerate ftfy pyarrow
扩散模型调度器
from diffusers import DDPMScheduler
# DDPM noise scheduler configured with 1000 training timesteps; presumably
# consumed by the training loop to noise clean images at sampled timesteps.
noise_scheduler = DDPMScheduler(
    num_train_timesteps=1000,
)
定义扩散模型
from diffusers import UNet2DModel
def model(*, sample_size=240, in_channels=4, out_channels=4):
    """Build the 2-D UNet used as the diffusion model.

    All arguments are keyword-only and default to the values this function
    originally hard-coded, so a plain ``model()`` call still returns the
    same architecture.

    Args:
        sample_size: Spatial size passed to ``UNet2DModel``. The default of
            240 presumably matches 240x240 BraTS slices -- confirm against
            ``dataset_brats_2D``.
        in_channels: Number of input image channels. The default of 4
            presumably matches the four MRI modalities loaded by
            ``dataset_brats_2D`` -- confirm.
        out_channels: Number of output channels. Defaults to 4 (equal to the
            default ``in_channels``), as is typical when the UNet predicts
            noise shaped like its input -- assumed usage, confirm in the
            training loop.

    Returns:
        A newly constructed ``UNet2DModel`` (built directly from this
        configuration, not loaded from a checkpoint).
    """
    # Named ``unet`` rather than ``model`` so the local does not shadow this
    # function's own name.
    unet = UNet2DModel(
        sample_size=sample_size,
        in_channels=in_channels,
        out_channels=out_channels,
        # Number of layers inside each down/up block.
        layers_per_block=2,
        # Output channels of each of the four resolution levels; this tuple
        # needs one entry per down block and per up block listed below.
        block_out_channels=(64, 128, 128, 256),
        # Plain blocks at the first two levels; attention blocks only at the
        # last two (deepest) levels, presumably to limit attention cost on
        # large inputs.
        down_block_types=(
            "DownBlock2D",
            "DownBlock2D",
            "AttnDownBlock2D",
            "AttnDownBlock2D",
        ),
        # Decoder mirrors the encoder: attention blocks first, then plain.
        up_block_types=(
            "AttnUpBlock2D",
            "AttnUpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
        ),
    )
    return unet
创建扩散模型训练循环
import torch.utils.data.dataset
import torchvision
from dataset import dataset_brats_2D
from torchvision import transforms
from diffusers import DDPMScheduler
import model
from torch.utils.data import DataLoader
import torch.nn.functional as F
import os
import time
if __name__ == "__main__":
device =