微信扫码
与创始人交个朋友
我要投稿
01
前言
本文主要分享我们在大模型知识蒸馏上的实验尝试,以及所取得的实验效果提升。我们在5月份完成了该实验,但由于各种原因,实验分享?️了两个月。由于彼时Qwen2尚未发布,且我们的训练资源有限,所以我们选择将Qwen1.5-32B-Chat-AWQ蒸馏到Qwen1.5-14B中。
在AlpacaEval 2.0和MT-Bench评测集中,我们得到以下两个主要的实验结果:
使用同一份训练数据,蒸馏得到的模型大幅优于直接SFT的模型。
蒸馏可弥补数据质量导致的差距,蒸馏得到的模型比官方Qwen1.5-14B-Chat模型有比较明显的提升。
下表是我们蒸馏的14B模型与官方Qwen1.5-14B-Chat在AplacaEval 2.0评测集中的评测结果。
这表明大模型知识蒸馏的有效性,相较于直接SFT,知识蒸馏能够进一步提升模型的性能,可作为大模型压缩和加速推理的有效手段。
近期谷歌开源的Gemma-2-9B也使用了知识蒸馏的方法,业内的许多闭源大模型必然也使用了知识蒸馏这一手段,将千亿大模型蒸馏到更小的模型,在保证性能的前提下,以提升模型的推理速度。
我们蒸馏的模型权重如下,更多实验细节详见下文:
https://hf-mirror.com/YeungNLP/firefly-qwen1.5-en-14b-alpha
训练代码基于Firefly项目:
https://github.com/yangjianxin1/Firefly
02
知识蒸馏简述
生成式大模型在许多任务中都取得了非常优异的表现,并且改变了NLP的范式。但“大”也导致了模型的推理瓶颈,出现模型推理速度慢,显存占用多等问题。在不改变模型结构和权重的前提下,可通过多卡推理、KV Cache、MQA、GQA、MLA、Page Attention等手段来缓解大模型的推理瓶颈。
知识蒸馏也是一种广泛使用的压缩大模型参数规模,提高模型推理速度,降低显存占用的方法。普遍的做法是让学生模型的logits拟合教师模型的logits,也可以进一步蒸馏教师模型的中间层的logits以及attention矩阵等。核心思想是尽可能地丰富学生模型训练时的监督信号,避免仅使用hard label这一单一的训练目标。
知识蒸馏并不属于特别新的技术,在bert时代已经有很多相关的研究和应用落地,一般是将较大的bert蒸馏到一个较小的bert或者textcnn、lstm等轻量级模型结构,使得模型能够在cpu上进行推理。蒸馏后得到的学生模型往往比直接finetune的效果更优,并且能够很大程度地保留教师模型的能力。
知识蒸馏的相关论文:
论文:Distilling the Knowledge in a Neural Network
链接:https://arxiv.org/pdf/1503.02531
论文:DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter
链接:https://arxiv.org/pdf/1910.01108
论文:Distilling Task-Specific Knowledge from BERT into Simple Neural Networks
链接:https://arxiv.org/pdf/1903.12136
论文:Patient Knowledge Distillation for BERT Model Compressio
链接:https://arxiv.org/pdf/1908.09355
03
蒸馏Qwen1.5
在本次实验中,我们将Qwen1.5-32B-Chat-AWQ蒸馏到Qwen1.5-14B模型结构中,使用Qwen1.5-14B的模型权重对学生模型进行初始化。我们采用ultrachat-200k作为训练数据,在sft和蒸馏时,均采用QLoRA训练策略。sft时的最大输入长度为2048,distill时的最大输入长度为1024。
在本次实验中,我们对比了不同的训练策略对模型性能的影响,对比实验一共包含四个模型:
Qwen1.5-14B-Chat:Qwen官方开源的chat模型。
ours-sft:对Qwen1.5-14B进行sft。
ours-distill:使用Qwen1.5-14B的权重初始化学生模型,Qwen1.5-32B-Chat-AWQ作为教师模型,仅蒸馏最后一层输出的logits。
ours-distill-sft:可以认为是ours-sft与ours-distill的结合,训练损失由sft loss与distill loss两部分组成,两者加权求和得到最终的训练loss。
我们在AlpacaEval 2.0和MT-Bench中,使用GPT-4o对上述四个模型进行自动评测。
MT-Bench包含单轮与多轮两类评测任务,评测结果如下表所示,可得出如下结论:
相较于直接SFT,蒸馏能够带来大幅的性能提升。
distill-sft的效果不如直接distill。我们猜测这可能不是普适的结论,可能是ultrachat的数据质量不如Qwen官方数据所造成的,若采用Qwen官方所使用的训练数据,distill-sft的训练效果可能会优于直接distill,但这一猜想有待通过实验验证。
在单轮对话评测中,直接蒸馏的模型优于官方的Qwen1.5-14B-Chat,但在多轮对话评测中,却出现了相反的现象。我们猜测可能是训练数据或者训练的最大输入长度导致的。
AlpacaEval 2.0评测集包含805条评测数据,5类评测任务,我们直接评测两个模型之间的胜负率,没有平局。在所有评测子任务中,我们蒸馏的模型的胜率均高于官方模型,总胜率为52.17%:47.83%。
另一个有趣的现象,无论从闭源和开源数据的差异来看,还是从MT-Bench中ours-sft显著弱于Qwen1.5-14B-Chat的表现来分析,我们都可以合理地认为,我们采用的ultrachat-200k的数据质量与Qwen官方的训练数据有比较大的差距,但通过蒸馏的手段,我们可以在一份相对“较弱”的数据上,取得更好的表现。
本实验还有许多待探索的点,包括但不限于,评测学生模型相较于教师模型的性能损失,学生模型的性能与教师模型性能的变化关系,进一步探索蒸馏时的训练目标对学生模型的影响等。
53AI,企业落地应用大模型首选服务商
产品:大模型应用平台+智能体定制开发+落地咨询服务
承诺:先做场景POC验证,看到效果再签署服务协议。零风险落地应用大模型,已交付160+中大型企业
2024-09-18
2024-07-02
2024-07-18
2024-07-09
2024-07-15
2024-08-14
2024-07-26
2024-07-10
2024-07-10
2024-10-17
2024-12-25
2024-11-20
2024-11-13
2024-10-31
2024-10-29
2024-10-16
2024-09-19
2024-08-28