模型蒸馏是深度学习中的一种技术,其中训练一个更小、更简单的模型(“学生”)来复制一个更大、更复杂的模型(“教师”)的行为。 其目标是将教师模型捕获的知识转移到更高效的形式,而不会显着降低性能。 这个过程对于在资源受限的环境中部署模型特别有用,例如移动设备或边缘计算系统,在这些系统中,原始模型的计算成本或内存占用可能令人望而却步。
核心思想是使用原始训练数据以及教师模型的输出来训练学生模型。 学生不仅仅依赖于硬标签(例如,分类任务中的类别索引),而是从教师的“软目标”中学习——教师对可能的类别生成的概率分布。 例如,在图像分类任务中,教师可能会为正确的类别(例如,“猫”)分配高概率,但也会为语义相关的类别(例如,“狗”或“老虎”)分配较小的概率。 这些细微的输出提供了比硬标签更丰富的指导,有助于学生模型更好地泛化。 常见的实现是使用损失函数,该函数结合了学生相对于真实标签的误差及其与教师的软预测的差异。 像温度缩放(调整 softmax 函数的锐度)这样的技术通常用于使教师的输出分布在训练期间更具信息性。
一个实际的例子是将像 BERT 这样的大型语言模型压缩成一个更小的变体(例如,DistilBERT)。 学生模型模仿教师在文本分类等任务上的预测,同时使用更少的层和参数。 同样,在计算机视觉中,ResNet-50 模型的知识可以被提炼成轻量级的 MobileNet,以便在移动设备上进行更快的推理。 好处包括减少延迟、降低内存使用率和更容易部署,但通常需要在大小和准确性之间进行权衡。 通过专注于教师学习到的模式而不是单独的原始数据,蒸馏可以实现保持原始模型大部分能力的有效模型。