日期:2025/04/01 22:29来源:未知 人气:53
基于SDXL 1.0 进行渐进式和对抗式蒸馏,1024生图分辨率,时间缩短为1步~8步,开源LORA和UNet权重
在技术上,SDXL-Lightning是基于Progressive Distillation和Adversarial Distillation来提升蒸馏效果。目前放出的模型包含1步,2步,4步和8步的蒸馏模型,可以直接生成1024x1024图像,相比之前的只能生成512x512图像的SDXL-Turbo有巨大优势
文生图Demo:https://huggingface.co/spaces/AP123/SDXL-Lightning
实时生图Demo:https://huggingface.co/spaces/radames/Real-Time-Text-to-Image-SDXL-Lightning
学习到的常微分方程流由前向调度、损失函数和模型容量决定。给定有限的训练样本,底层数据分布是模棱两可的。最大似然估计(MLE)是一种分布,它仅将偶数概率分配给观察到的样本,而将其他任何地方的概率分配给零。如果模型具有无限容量,它将学习这种最大似然估计和过拟合的流程,以始终生成观测样本并且不生成新数据。在实践中,扩散模型可以生成新数据,因为神经网络不是完美的学习者。当模型用于多步生成时,它是堆叠的,并且具有更高的 Lipschitz 常数和更多的非线性,以近似更复杂的分布。但是,当模型在几步生成中使用时,它不再具有相同的容量来近似相同的分布。尽管初始噪声变化很小,但扩散模型的结果可能会有非常剧烈的变化,但蒸馏模型具有更平滑的潜在遍历。这就解释了为什么 MSE 损失的蒸馏会产生模糊的结果。学生模型根本没有能力与老师相匹配。我们发现其他距离指标,如 L1 和感知损失,也会产生不良结果。
对于渐进式蒸馏的基本方式,是计算teacher-student之间的MSE损失,teacher是使用多个step达到的结果,作为蒸馏的模型,蒸馏到更少步数的students中,一旦学生模型收敛,它就被用作教师模型,并重复蒸馏过程。从理论上讲,它可以生成一步生成模型。
不同于计算当前流位置下的梯度,模型蒸馏改变模型预测的目标,直接让其预测下一个更远的流位置。通过训练一个学生网络直接预测老师网络完成了多步推理的后的结果。这样的策略可以大幅减少所需的推理步骤数量。通过反复应用这个过程,可以进一步降低推理步骤的数量。这种方法被先前的研究称之为渐进式蒸馏。
在实际操作中,学生网络往往难以精确预测未来的流位置。误差随着每一步的累积而放大,导致在少于 8 步推理的情况下,模型产生的图像开始变得模糊不清。
文章解决方案:使用对抗性判别器代替teacher-student之间的MSE损失。
建立对抗性判别器,计算来自teacher在输入xt和条件c下产生的x(t-ns)的概率,使用非饱和的对抗损失,交替训练判别器和学生模型,鼓励students模型的预测结果x'(t-ns)更接近于teacher模型的预测结果x(t-ns)。
xt上的条件对于保持 ODE 流很重要。这是因为教师生成的xt−ns是由xt确定的。通过向判别器提供xt−ns和xt,判别器可以学习底层的ODE流程,并且学生也必须遵循相同的流程来欺骗判别器。
通俗来说,即
不强求学生网络精确匹配教师网络的预测,而是让学生网络在概率分布上与教师网络保持一致。换言之,学生网络被训练来预测一个概率上可能的位置,即使这个位置并不完全准确,不会对它进行惩罚。这个目标是通过对抗训练来实现的,引入了一个额外的判别网络来帮助实现学生网络和教师网络输出的分布匹配。
基于SDXL的UNet网络做相关改进后作为判别器。预训练权重作为初始化,然后在后续的训练中训练整个判别器。
对抗性目标函数是鼓励预测既敏锐又保持动态,但这并不能改变学生没有足够的能力与老师完美匹配的事实。使用 MSE时,它表现出模糊的结果。但在对抗性的目标下,它表现了“Janus”的伪影问题。
如2所示,教师模型有时可以对相邻的噪声输入产生剧烈的布局变化,但学生模型没有相同的能力来做出如此剧烈的变化。因此,对抗性损失牺牲了语义正确性,以保持清晰度和布局,表现出具有连体头部和身体特征的伪影。语义正确性比人类偏好的模式覆盖更重要。因此,在用原来的对抗目标进行训练后,作者放宽了这种流的保持要求。具体来说,
在没有条件的情况下进一步微调模型
xt:
针对这一目标进行微调可以有效地去除 “Janus”伪影,同时在实践中在很大程度上仍保留原始流程。
因此,在渐进式蒸馏的每个阶段,首先使用有条件目标进行训练,然后使用该无条件目标进行微调。
由于无条件目标仅涉及每个样本的质量,因此使用跳级teacher来提炼一步和两步模型,以进一步保持质量并减少错误累积。
先前的工作表明,通用的扩散时间表是有缺陷的。具体来说,
在训练期间,时间表在t=T时没有达到纯噪声,但在推理过程中给出了纯噪声,从而导致差异。
不幸的是,SDXL 使用了这个有缺陷的时间表。在大量的推理步骤下,这种效果不太明显,但对于几步的生成步骤中尤其有害。
规避该问题的一种巧妙方法是在训练期间将纯噪声 ϵ 硬交换为t=T的模型输入
。这样一来,模型被训练为期望纯噪声作为t=T输入,仍然在推理时使用旧时间表的以避免奇异性。
蒸馏的程序
首先,直接进行从128步到32步的蒸馏,并有MSE损失
然后,切换到使用对抗性损失来按以下顺序提炼步数:32 → 8 → 4 → 2 → 1 。在每个阶段,首先使用条件目标进行训练,然后使用无条件目标进行训练
在每个阶段,首先使用两个目标使用LoRA进行训练,然后合并LoRA并使用无条件目标进一步训练整个UNet。发现对整个UNet进行微调可以获得更好的性能,而LoRA模块可用于其他基本型号。LoRA设置与 LCM-LoRA相同,后者在所有卷积和线性权重上都使用秩64,但输入和输出卷积以及共享时间嵌入线性层除外。在鉴别器上不使用LoRA。在每个阶段重新初始化鉴别器。
训练数据
:LAION、COYO,分辨率大于1024,美学评分大于5.5,进行清晰度过滤,text prompt清洗
训练成本
:64 * A100(80G),bs=512,lr=1e-5,etc.
与以前的方法(Turbo 和 LCM)相比,生成的图像在结构和细节上有显著改进,并且更忠实于原始生成模型的风格和布局。
4步和8步模型在32步方面通常可以胜过原始的 SDXL模型
蒸馏 LoRA 模型可以应用于不同的基础模型。蒸馏LoRA能够在很大程度上保持新基础模型的风格和布局
在执行 1 步和 2 步生成时出现越来越多的不良情况(bad case)
随着推理步骤数量的减少,质量有所下降。
论文
:https://arxiv.org/pdf/2402.13929.pdf
HuggingFace
:https://huggingface.co/ByteDance/SDXL-Lightning
与其他方案比较