从0开始知识蒸馏
翻译自 https://keras.io/examples/vision/knowledge_distillation/#train-student-from-scratch-for-comparison
更多关于蒸馏和模型推理加速的知识可参考博客《预训练模型参数量越来越大?这里有你需要的BERT推理加速技术指南》
知识蒸馏简介
知识蒸馏是一个模型压缩的过程,其中训练一个小的(学生)模型来匹配一个大的预训练(教师)模型。通过最小化损失函数将知识从教师模型转移到学生,旨在匹配软化的教师逻辑和真实标签。
通过在 softmax 中应用“温度”缩放函数来软化对数,有效地平滑概率分布并揭示老师学到的类间关系。
设置
1 |
|
构造 Distiller()
类
自定义Distiller()
类,覆盖Model
方法train_step
,test_step
以及compile()
。为了使用蒸馏器,我们需要:
- 已经训好的教师模型
- 要训练的学生模型
- 关于学生预测和 ground-truth 之间差异的学生损失函数
- A distillation loss function, along with a
temperature
, on the difference between the soft student predictions and the soft teacher labels - 一个
alpha
因素加权学生和蒸馏损失 - 学生和(可选)指标的优化器来评估性能
在该train_step
方法中,我们执行教师和学生两者的 forward pass,计算student_loss
和distillation_loss
的加权损失(alpha
与 1 - alpha
),并执行 backward pass。Note: only the student weights are updated, and therefore we only calculate the gradients for the student weights.
在test_step
方法中,我们在提供的数据集上评估学生模型。
1 |
|
创建学生和教师模型
最初,我们创建了一个教师模型和一个较小的学生模型。这两个模型都是卷积神经网络,使用Sequential()
,但可以是任何 Keras 模型。
1 |
|
准备数据集
用于训练教师和提炼教师的数据集是 MNIST,该过程对于任何其他数据集都是等效的,例如CIFAR-10,具有合适的模型选择。学生和教师都在训练集上接受训练,并在测试集上进行评估。
1 |
|
Train the teacher
在知识蒸馏中,我们假设老师是经过培训和固定的。因此,我们首先以通常的方式在训练集上训练教师模型。
1 |
|
Distill teacher to student
我们已经训练了教师模型,我们只需要初始化一个 Distiller(student, teacher)
实例,compile()
它具有所需的损失、超参数和优化器,并将教师提炼给学生。
1 |
|
从头开始训练学生进行比较
我们还可以在没有老师的情况下从头开始训练一个等效的学生模型,以评估通过知识蒸馏获得的性能提升。
1 |
|
如果教师接受了 5 个完整的 epochs 训练,而学生在这个教师身上被提炼了 3 个完整的 epochs,那么在这个例子中,与从头开始训练相同的学生模型相比,甚至与教师本身相比,都得到了性能提升。