Loading [MathJax]/jax/output/CommonHTML/jax.js

2. Stable diffusion——DDPM

1. 逆向过程

逆向过程是将得到的噪声变为原始图片,具体来说就是从 xT 逐步回到原始的图片 x0,数学表示为:
q(xt1|xt)
即已知 xT,而他是通过 t1 步时加噪来的,那么去掉这个噪声后 xt1 是什么样的呢,如果能得到这个概率分布模型,就可以一步步返回去求解到的 x0。具体推导 简单基础入门理解 Denoising Diffusion Probabilistic Model,DDPM 扩散模型。最后得到:

˜μt(xt)=1αt(xt1αt1ˉαt˜zt)

即在时间步 t 时, 学习的模型 ˜μt 的输出, 即对于输入 xt, 模型预测的均值 ˜μt(xt)

其中:

  • αtˉαt 分别是时间步 t 和之前所有时间步的扩散程度系数的积
  • xt 是时间步 t 的观测值, 包含了噪声
  • ˜zt 是一个辅助噪声变量, 服从标准高斯分布 N(0,I)

这个公式的推导利用了扩散过程中 xtx0 之间的关系:

xt=ˉαtx0+1ˉαtϵ

其中 ϵ 是一个标准高斯噪声。

通过一些代数运算, 可以得到上面那个 ˜μt(xt) 的表达式, 它实际上是对 x0 的估计, 也就是去噪后的结果。

在训练过程中, 我们最小化 ˜μt(xt) 与真实的 x0 之间的差异, 从而学习这个模型的参数。一旦训练好, 就可以用它进行逐步去噪和生成新样本了。

2. 模型训练

2. Stable diffusion——DDPM

1) 首先是从真实图像分布 q(x0) 中采样一个原始的清晰图像 x0

2) 随机选择一个扩散时间步 t, 从均匀分布 U(1,T) 中采样, 其中 T 是最大扩散步数。

3) 从标准高斯分布 N(0,I) 中采样一个噪声 ϵ

4) 根据之前推导的关系:

xt=ˉαtx0+1ˉαtϵ

我们可以根据采样得到的 x0, tϵ, 构造出对应的 xt。这个 xt 包含了噪声, 是我们的输入数据。

5) 目标是最小化真实噪声 ϵ 与模型预测噪声之间的均方差:

Loss=||ϵϵθ(xt,t)||2

其中 ϵθ(xt,t) 是神经网络模型的输出, 也就是对于输入 xt 和时间步 t, 模型需要预测出对应的噪声 ϵθ(xt,t)

6) 之所以选择这个损失函数, 是因为根据之前的公式:

xt=ˉαtx0+1ˉαtϵ

如果模型能够精确预测出 ϵ, 那么我们就能够从 xt 中恢复出 x0: 即

x0=xt1ˉαtϵθ(xt,t)ˉαt
所以这个损失函数的本质, 是在训练模型去精确捕获噪声分布, 以便从噪声数据中精确地去噪并重建原始图像。通过最小化这个损失, 我们能够训练出逆扩散 (降噪) 模型的参数 θ。这个损失函数的设计, 充分利用了已知的正向扩散过程的数学关系, 使得模型能够从噪声数据中学习重建原始信号的映射, 是 diffusion 模型训练的关键所在。

3. 模型推理

对应上图的采样

1) 从标准高斯分布中采样得到一个噪声, 记为 xT。由于原始图像 x0 经过 T 次加噪后最终得到的也是一个标准高斯噪声, 因此这里采样得到的噪声我们记为 xT

2) 进行 T 次逆扩散过程, 将图像从高斯噪声 xT 中恢复出来。对于每次逆扩散过程:

3) 随机采样一个标准高斯噪声 z。注意在最后一步时不采样, 令 z=0, 这是一个技巧, 不影响对整体的理解。

4) 通过公式计算得到去噪一次的结果:

xt1=1αt(xt1αt1ˉαtϵθ(xt,t))+σtz

这个式子的理解依赖于重参数化技巧。从分布的角度, 比如从 N(μ,σ2) 中采样得到一个 ϵ, 可以写成:

ϵ=μ+zσ
其中 z 为标准高斯噪声。根据之前的公式, 我们知道高斯分布 q(xt1xt) 的均值 μt 为:

μt=1αt(xt1αt1ˉαt˜zt) 其中 ˜zt 实际上就是网络 ϵθ(xt,t) 能够预测的东西, 直接替换即可。再加上方差 σ2t, 有:

q(xt1xt)N(1αt(xt1αt1ˉαtϵθ(xt,t)),σ2t)
这样通过 T 次这样的采样, 我们就可以从初始的噪声 xT 逐步重构出最终的图像 x0。也就是说我们训练好 diffusion model 之后,只要一个噪声就可以构建一个图像了。

4. 总结

训练过程:

  1. 从真实图像分布 q(x0) 中采样一个原始清晰图像 x0
  2. 从均匀分布 U(1,T) 中随机采样一个时间步 t
  3. 将时间步 t 使用某种编码方式 (如求 sin/cos 值) 嵌入为时间嵌入向量 t
  4. 根据公式 xt=ˉαtx0+1ˉαtϵ 生成对应的噪声图像 xt
  5. 将噪声图像 xt 和时间嵌入 t 作为输入, 送入 U -Net 模型
  6. U-Net 模型输出预测的噪声 ϵθ(xt,t)
  7. 计算预测噪声与真实噪声 ϵ 之间的损失: |ϵϵθ(xt,t)|2
  8. 使用优化器 (如 Adam) 更新 U -Net 模型参数 θ, 最小化损失函数
  9. 重复 1 -8, 不断采样新的图像和时间步, 训练模型直到收敛

推理过程:

  1. 从噪声先验分布 (如标准高斯) 中采样一个初始噪声图像 xT
  2. 对于 t=T,T1,,1,0, 利用训练好的 U -Net 模型进行逐步采样:

xt1=1αt(xt1αt1ˉαtϵθ(xt,t))+σtz

其中 zN(0,I) 是重参数的高斯噪声。

  1. 最终得到的 x0 就是生成的新图像样本

通过上述训练过程,U-Net 模型就能够学习到去噪声 (denoising) 的映射, 使其输出的噪声 ϵθ 足够逼近真实噪声分布, 从而可以用于从噪声生成新图像样本。关键是 U -Net 与时间嵌入向量的结合, 让模型能够学习到任意时间步的去噪映射。

正文完
 
admin
版权声明:本站原创文章,由 admin 2024-03-27发表,共计4501字。
转载说明:除特殊说明外本站文章皆由CC-4.0协议发布,转载请联系tensortimes@gmail.com。