AI知识库

53AI知识库

学习大模型的前沿技术与行业应用场景


大模型知识蒸馏的两种方式
发布日期:2024-09-07 07:23:15 浏览次数: 2957 来源:NetRookie

      上个月llama3.1的405B已经发布,除了感叹开源模型效果的厉害之外,另一个普遍的感受就是,跑不动,根本跑不动,没资源,就算能训练,也部署不起。所以很多人就自然而然关注到了知识蒸馏,通过将大模型能力迁移到小模型能力上。于是大概调研了下,本文主要是对清华的《MiniLLM:Knowledge Distillation of Large Language Models》和Meta的《Distilling System2 into System1》一些解读,刚好他们分别作为白盒蒸馏和黑盒蒸馏的一个典型代表。

       在知乎搜了一下minillm相关的文章,如《 吃果冻不吐果冻皮:大模型知识蒸馏概述 》总结性的介绍了下minillm的逆向kl散度的思路。即最小化前向 Kullback-Leibler 散度 (KLD) 的挑战为教师分布中不太可能的区域出现概率过高,从而在自由运行生成过程中导致不可能的样本 。为了解决这个问题,MINILLM 选择最小化逆向 KLD。这种方法可以防止学生高估教师分布中的低概率区域,从而提高生成样本的质量。但具体原因只在论文中才更清楚,于是部分细节整理如下,本文主要对FKL和RKL差异以及从强化学习视角看MiniLLM做一些介绍,其他论文细节没有涉及太多。

MiniLLM蒸馏

Motivation

前向KL散度倾向于学习mean-seeking,反向kl散度学习mode-seeking

首先需要明确KL散度的非对称性质,即前向和后向是当前仅当两个分布完全相等时才等价的,然后我们分开看一下两个kl散度的具体公式。了解前向KL散度和KL散度分别会导致mode seeking 和 mean seeking 产生的原因在于:

  • 当p(x)较大时,qθ(x) 也需要比较大且比p(x)相对更大,否则公式右边很大的情况下,FKL整体就无法达到最小;

  • 当p(x)较小时,p(x) 在 log 外趋于0占主导,FKL整体总是能比较小,跟qθ(x)关系较小。所以在优化的时候,qθ(x) 会覆盖p(x)的所有mode,即便此时有可能导致高估 p(x) 很小的部分,对应上述图中的橙色部分。

  • 当qθ(x) 较大时,为了在优化时候降低RKL,p(x) 必须较大,因此 p(x) 概率最大的 mode 也要对应 qθ(x) 概率最大的地方,p(x) 概率很小的地方必须对应 qθ(x) 概率为0的地方,也就是说 qθ(x) 拟合了 p(x) 概率最大的部分。对应前述图中的绿色部分。

  • 当 qθ(x) 等于0时,p(x) 取什么样的值都不影响优化。

由此可以看下MiniLLM具体的方法图:

RKL和Inverse RL的等价的数学推导

论文中的另一个视角个人觉得特别好,就是将RKL和逆强化学习进行对比,并给出了数学说明,可以看一下

公式说明:这里的公式序号均来自论文本身,目的是结合论文一起看可能更好,不破坏原有公式顺序。


既然可以这么类比,RKL约等于逆强化学习,FKL等价于模仿学习,而在实际应用和理论说明中,逆强化学习的效果都会比模仿学习更优,虽然更加难以训练,但其泛化性能,理论上限肯定会更高,所以结论是MiniLLM的RKL理论上是更优的。 模仿学习和逆强化学习这个说明可以查看:https://www.zhihu.com/question/470949607/answer/2450111740?utm_id=0

实际怎么训练

上面两个部分,其实都在说明,MiniLLM理论证明上是更优的蒸馏方法。所以我们可以去进行大胆尝试。实际的训练过程,类似于RLHF的训练方式,教师模型在训练中只推理,作为奖励信号去训练模型。作者也提供了类似ranking loss的更简单的平替方式去优化,对比传统Bert时代的蒸馏方法都会有提升。感谢作者!

# https://github.com/microsoft/LMOps/blob/main/minillm/finetune.py#L166
# 这里是实际蒸馏loss的计算
def get_distil_loss(args, tokenizer, model, teacher_model, model_batch, no_model_batch, logits):
with torch.no_grad():
teacher_model.eval()
teacher_outputs = teacher_model(**model_batch, use_cache=False) # 教师模型推理
teacher_logits = teacher_outputs.logits # 获取教师分布logits
if args.model_parallel:
distil_losses = mpu.parallel_soft_cross_entropy_loss(logits.float(), teacher_logits.float())
distil_losses = distil_losses.view(-1)
loss_mask = no_model_batch["loss_mask"].view(-1)
distil_loss = (distil_losses * loss_mask).sum(-1) / loss_mask.sum(-1)
else:
teacher_probs = F.softmax(teacher_logits, dim=-1, dtype=torch.float32) #
inf_mask = torch.isinf(logits)
#log_softmax实际上是在教师和学生的交叉熵;交叉熵损失在形式上等价于KL散度减去一个常数项(分布P 的熵)在最小化KL散度时可以忽略
logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
prod_probs = torch.masked_fill(teacher_probs * logprobs, inf_mask, 0)
x = torch.sum(prod_probs, dim=-1).view(-1)
mask = (no_model_batch["label"] != -100).int()
distil_loss = -torch.sum(x * mask.view(-1), dim=0) / torch.sum(mask.view(-1), dim=0)

return distil_loss

一般来说实际训练中还会加上sft数据的loss确保不跑偏,类似rlhf中的reference model的作用
output = self.model(inputs, attention_mask=attention_mask, return_output=True)
sft_loss = self.loss_fn(output.logits, labels)

需要注意教师模型和学生模型需要使用同源的模型。即相同的tokenizer,对于国产模型来说,qwen、deepseek、yi等都有相同tokenizer不同尺寸的模型可供选择。

System2到System1蒸馏

整体说明:

人类认知系统中的两种推理系统,系统1和系统2,系统1被认为是无意识的,能够快速识别和迅速判断,也叫做快思考,系统2被认为是处理复杂问题如数学和逻辑问题,需要深思熟虑,也叫做慢思考。

在大模型中,可以将中间的流程如多次调用大模型、中间的思考tokens类比为深思熟虑的过程,这些方法如cot、RaR等等带来更好的推理效果,但与此同时,耗时问题会导致这些方法很难用于生产落地。于是很多方法都在尝试将系统2的效果蒸馏到系统1当中(毕竟自2023年 gpt4出来后,应该有非常多的黑盒蒸馏gpt4数据训练到各家系统中;还有很久之前llama2的ghost attention:在每一句中都加入system prompt让 Llama 2 有效地遵循多轮指令,都是一些蒸馏的有效形式)。

这篇论文的主要与之前差异点在于,显式的提出System2的推理能力蒸馏到System1中,并做了很多实验进行验证。可以理解为论文提供了非常好的一种数据合成的范式,通过使用这些数据进行指令微调等方法,提升System1的推理能力。

以下几个公式是对System1和System2的形式化说明:

也就是说,通过上述公式3可以得到的大量训练数据,但是实际会存在质量问题。论文主要通过一致性标准进行过滤。

  • 输出一致性:输入不变,对输出进行N次采样,通过投票实现,少数服从多数

  • 输入扰动下的一致性:输出不变,对输入增加扰动,比如选择题改顺序但答案没变化,不一致则过滤

但猜测实际可能有更多更精细化的方式实现。

然后就是四种方式在不同数据集上的效果,我觉得给出Prompt可能是最好的方法体现形式

RePhrase And Respond Distillation

Prompt:

"{question}"\nRephrase and expand the question, and respond.

让模型先改写,改写可能提供更丰富的文本信息,然后再回答,能让大模型用自己的知识体现理解问题,回答问题。

System2 Attention Distillation

让大模型过滤无效信息,去除有偏信息和不相干上下文,然后再改写基础上进行回答

Branch-Solve-Merge Distillation

Chain Of Thought Distillation

论文通过这四种System2的方式,蒸馏到System1当中,做了很多实验,结果就不一一贴了,都是差的也不会发paper,总结下来,整体是有效的,如RaR蒸馏可用于澄清任务指令相关任务、S2A能有效提升有偏任务,Branch-Solve-Merge蒸馏能作为LLM-Judge评估任务,但是在复杂推理任务上的蒸馏,目前还做得不好。这可能也是一个共识,需要持续研究。

总结,不管是黑盒蒸馏,还是白盒蒸馏,都是现如今非常好的将更大模型的知识注入到较小模型中去的方式,不断提升小模型的知识密度,这样可以再更多的落地场景中应用。期待这个方向后续更多的工作。



53AI,企业落地大模型首选服务商

产品:场景落地咨询+大模型应用平台+行业解决方案

承诺:免费场景POC验证,效果验证后签署服务协议。零风险落地应用大模型,已交付160+中大型企业

联系我们

售前咨询
186 6662 7370
预约演示
185 8882 0121

微信扫码

与创始人交个朋友

回到顶部

 
扫码咨询