本文主要记录不同的SD模型代码中实现的一些细节,如text2img,img2img,inpaint等等
以下代码来自SD1.5: https://github.com/runwayml/stable-diffusion
推理代码
1. 文生图
第一步:生成随机的latent feature (n,4,64,64);n为生成的图片个数;
第二步:对于prompt用clip生成特征,正向提示词的特征(n,77,768),反向提示词的特征(n,77,768)。后续用公式得到最终的噪声:
第三步:生成时间步伐的特征(n,320);
第四步:latent feature放入UNet中,用cnn和多头注意提取特征并下采样,每次都要先加上+时间特征,然后和提示词特征用交叉注意力生成新的特征;
第五步:最终UNet输出的是噪声(n,4,64,64);
第六步:根据之前文章的公式对latent feature (n,4,64,64)去燥,得到新的latent feature (n,4,64,64);不断的采样得到最终的latent feature
第七步:最终的latent feature 放到VAE解码器中,放大8倍得到输出图片(n,3,512,512);
2. 图生图
第一步:与文生图不同,这里的latent feature (n,4,64,64)不是随机生成的,是输入图片用VAE编码器得到的;注意这里有个系数strength,0表示对latent feature不加噪声,1表示加很多噪声完全摧毁原有的信息;
第二步:对于prompt用clip生成特征,正向提示词的特征(n,77,768),反向提示词的特征(n,77,768)。
第三步:生成时间步伐的特征(n,320);
第四步:latent feature放入UNet中,用cnn和多头注意提取特征并下采样,每次都要先加上时间特征,然后和提示词特征用交叉注意力生成新的特征;
第五步:最终UNet输出的是噪声(n,4,64,64);
第六步:根据之前文章的公式对latent feature (n,4,64,64)去燥,得到新的latent feature (n,4,64,64);不断的采样得到最终的latent feature
第七步:最终的latent feature 放到VAE解码器中,放大8倍得到输出图片(n,3,512,512);
3. Inpainting
第一步:生成随机的latent feature (n,4,64,64);n为生成的图片个数;
第二步:得到mask img (n,3,512,512)和mask(n,1,512,512);
第三步:mask下采样到(n,1,64,64);mask img用编码器得到(n,4,64,64);把这两个cat起来得到c_cat (n,5,64,64);
第四步:对于prompt用clip生成特征,正向提示词的特征(n,77,768),反向提示词的特征(n,77,768);
第五步:生成时间步伐的特征(n,320);
第六步:把c_cat (n,5,64,64)和latent feature (n,4,64,64) 连接起来得到网络的输入 (n,9,64,64);也就是说输入UNet的xc是图像特征加上latent feature;而cc是prompt的特征,out = self.diffusion_model(xc, t, context=cc)
;如下图所示:4个latent feature+4个mask img feature+1 mask;
第七步:把上面这些输入UNet中,用cnn和多头注意提取特征并下采样,每次都要先加上时间特征,然后和提示词特征用交叉注意力生成新的特征;
第八步:最终UNet同样输出噪声(n,4,64,64);
第九步:根据之前文章的公式对latent feature (n,4,64,64)去燥,得到新的latent feature (n,4,64,64);不断的采样得到最终的latent feature;
第十步:最终的latent feature 放到VAE解码器中,放大8倍得到输出图片(n,3,512,512);
以下代码来自SDXL: https://github.com/Stability-AI/generative-models
训练:
- 文生图
a. 读取cifar10数据集,图片(n,3,32,32)和类别(n,);
b. 提取n个类别的embedding,cond=(n,128);
c. 生成噪声(n,3,32,32);对图像加噪得到input;
d. 对sigma处理得到c_skip, c_out, c_in, c_noise;如下图计算得到输入数据;
e. 在网络中,对于类别的 embedding,直接加到时间步长的emb上就好;
f. 后续就是正常的交叉注意力,得到输出(n,3,64,64);
g. 最后就是输入图像和网络输出计算mse;
推理:
-
文生图
a. 文本特征用两个clip编码器得到特征后拼接起来得到(n,77,2048),作为cross attention;
b. 把目标分辨率、裁剪的左上角坐标进行余弦编码,然后把它们的特征cat起来得到vector(n,2816);
c. 随机生成(n,4,128,128)的latent feature;
d. 把上面的latent feature作为网络的输入x,文本特征(n,77,2048)作为context;vector作为y,y和时间步长特征相加;
e. 然后在UNet的每个block中都会加上+上面的emb(相当于每个block都学到了目标分辨率、步长等特征);并且用cross attention学习文本特征context。
f. 最终得到噪声(n,4,128,128);
g. 用euler采样公式减去噪声,得到这次去燥后的latent feature(n,4,128,128);
h. 用解码器得到 第一阶段 sdxl_base 模型下的生成图片(1024,1024,3);
i. 如果有第二阶段sdxl_refine,其实就是步骤g得到的latent feature再做多次去噪,最后再用解码器得到sdxl_refine模型下的解码图像; -
图生图
a. 其他步骤都和文生图一样。不一样的点在于:上面c步骤是随机生成的latent feature作为网络的输入,而这里是图片经过VAE encode得到了1个latent feature加上随机的latent feature作为网络的输入;
其他
- unet_added_cond_kwargs = {“text_embeds”: pooled_text_embeds, “time_ids”: add_time_ids},这个是文本的全局特征和分辨率,加到步长ebemding中,然后在每个block中+上emb即可;
- 图片特征提取器encoder_hid_proj,提取到图片的特征,然后和文本cat起来,后面就可以做cross attention了;