微信扫码
与创始人交个朋友
我要投稿
模型蒸馏技术深度解析,助力资源受限设备性能提升。核心内容:1. 模型蒸馏技术概述及其在深度学习中的应用价值2. 知识蒸馏的核心组件:知识、蒸馏算法、师生架构3. 知识蒸馏流程详解,包括soft targets的作用与影响
最近wsdm cup到了瓶颈,租卡跑算力成本太高,而lmsys比赛的微调结果也没啥可抄的了,所以只能回头看看top方案,研究了一下阳哥的《Distill is all you need》,和第二名tascj对于训练推理的科技与狠活,有些感觉,伴随着deepseek的大火,蒸馏和强化学习又被端上了台面,我对强化学习暂时没什么兴趣,不过蒸馏跟我最近看的内容相关,在网上搜了一圈关于deepseek针对蒸馏的策略,好像没有过多内容介绍,于是想着总结找到的一些资料。
模型蒸馏即知识蒸馏(Knowledge Distillation),是一种模型压缩和加速技术。在深度学习中,大型深度神经网络虽性能优异,但因计算复杂度高、存储需求大,难以部署在资源受限设备上。模型蒸馏通过构建师生架构,让小的学生模型学习大的教师模型的知识,使学生模型在保持较小规模的同时,尽可能接近教师模型的性能。其核心组件包括知识(如教师模型的 logits、中间层特征等)、蒸馏算法(用于指导知识转移)和师生架构(决定知识传递方式)。
这里可以看比较主流的一张图,出自2021年综述:《Knowledge Distillation: A Survey》,对近年的Distillation做了一个详细概括,Knowledge Distillation的流程可以理解为:
图中除了loss之后会详细说明,唯一的未知点可能在于soft targets,它是经过softmax的下一层级结果logits(原始分数),公式为:
其中是温度系数,从公式中能很明显看出当值较大时,Softmax 输出的概率分布会更加平滑,每个类别的概率值相对更接近;值较小时,概率分布会更尖锐,高概率类别的概率值远高于其他类别。这些 soft targets 会传递给学生模型,学生模型在学习过程中不仅学习真实的hard targets信息,还能从教师模型的 soft targets 中获取类别之间的关联等知识,帮助其更好地训练和泛化。
hard targets 与 soft targets的区别可以从下面的四分类图中很形象的看出:
做知识蒸馏的方式有非常多,从训练方案流程来看,就有离线蒸馏、在线蒸馏和自蒸馏等,从算法更新角度上,还有对抗蒸馏、多教师蒸馏等,这里我就不用豆包在灌水了,想查一大片说明,直接以bert时代的蒸馏开始看。
TinyBERT是一种轻量级的预训练语言模型,由华为和华中科技大学提出。它通过知识蒸馏技术,将BERT模型的知识迁移到一个更小的模型中,从而实现了模型体积的大幅减小和推理速度的提升。在当时,它提出了 两阶段transformer蒸馏方案:在大规模语料上首先进行通用MLM任务的蒸馏,在下游任务时,先学好老师模型,再进行蒸馏,具体如下图:
关于Transformer层蒸馏,主要包括注意力attn的蒸馏和隐藏层hidn的蒸馏:
关于损失函数,TinyBert的蒸馏loss为:
第一项:词向量层损失
第二项:中间层损失
第三项:预测层损失
如果有不清晰的,可以去看论文原文,我就不做过多解释了,上述的内容根据论文开源的github地址,其中对于蒸馏训练的截取部分,可进行一一对照:
# 蒸馏配置
distill_config = DistillationConfig(
# 设置温度系数temperature, tiny-bert论文作者使用1表现最好,一般大于1比较好
temperature=self.temperature,
# 设置ground truth loss权重
hard_label_weight=self.hard_label_weight,
# 设置预测层蒸馏loss(即soft label损失)为交叉熵,并稍微放大其权重
kd_loss_type=self.kd_loss_type, kd_loss_weight=self.kd_loss_weight,
# 配置中间层蒸馏映射
intermediate_matches=[
# 配置hidden蒸馏映射、维度映射
{'layer_T': 0, 'layer_S': 0, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1,
'proj': ['linear', 312, 768]}, # embedding层输出
{'layer_T': 3, 'layer_S': 1, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1,
'proj': ['linear', 312, 768]},
{'layer_T': 6, 'layer_S': 2, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1,
'proj': ['linear', 312, 768]},
{'layer_T': 9, 'layer_S': 3, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1,
'proj': ['linear', 312, 768]},
{'layer_T': 12, 'layer_S': 4, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1,
'proj': ['linear', 312, 768]},
# 配置attention矩阵蒸馏映射,注意layer序号从0开始
{"layer_T": 2, "layer_S": 0, "feature": "attention", "loss": "attention_mse", "weight": 1},
{"layer_T": 5, "layer_S": 1, "feature": "attention", "loss": "attention_mse", "weight": 1},
{"layer_T": 8, "layer_S": 2, "feature": "attention", "loss": "attention_mse", "weight": 1},
{"layer_T": 11, "layer_S": 3, "feature": "attention", "loss": "attention_mse", "weight": 1},
]
)
# 训练配置
optimizer = AdamW(self.student_model.parameters(), lr=self.lr) # 使用大一点的lr
train_config = TrainingConfig(
output_dir=self.student_model_dir, device=self.student_trainer.device,
data_parallel=self.enable_parallel, ckpt_frequency=self.ckpt_frequency # 一个epoch存ckpt_frequency次模型
)
# 配置model中logits hiddens attentions losses的获取方法
def simple_adaptor(batch, model_outputs):
return {
'logits': model_outputs[-1]['logits'], 'hidden': model_outputs[-1]['hiddens'],
'attention': model_outputs[-1]['attentions'], 'losses': model_outputs[1],
}
# 蒸馏
distiller = GeneralDistiller(
train_config=train_config, distill_config=distill_config,
model_T=self.teacher_model, model_S=self.student_model,
adaptor_T=simple_adaptor, adaptor_S=simple_adaptor
)
with distiller:
logger.info('start to knowledge distill ...')
distiller.train(optimizer, train_dataloader, num_epochs=epoch)
logger.info('distill finish')
KL散度的定义是建立在熵(Entropy)的基础上的。此处以离散随机变量为例,若一个离散随机变量的可能取值为,而对应的概率为,则随机变量的熵定义为:
若有两个随机变量,且其概率分布分别为,则相对的相对摘为:
之所以称之为相对熵,是因为其可以通过两随机变量的交叉嫡(Cross-Entropy)以及信息摘推导得到,针对上述离散变量的概率分布而言,其交叉摘定义为:
因此,KL散度或相对熵可通过下式得出:
在上一节中,TinyBERT在设计其蒸馏过程时采用了多种损失函数,包括词向量层损失、中间层损失和预测层损失,在大模型时代下,词向量损失不用多说,因为已经完全做了解耦,如何进行embedding我想看到这里的都知道,中间层损失的不再使用,或者说中间层蒸馏的使用变少,我理解是大模型通常已经具有足够的参数来学习复杂的特征表示,因此它的必要性相对较低,另外就是中间层叠得太厚,所能获得的收益太低,所以不如针对预测层进行相应的改进,那自然,就不得不提本节在介绍的KL散度。
那为什么作为大模型来讲,更多使用KL散度呢?我觉得可以从以下三点考虑:
上述介绍了KL散度的定义,很明显,KL损失不是一个对称形式,即,那么我们可以试图用近似分布来优化该目标:
根据上一小节的概率公式推导,可以计算出反向 (Reverse KL,RKL)为:
正向 (Forward KL,FKL)为:
其中P是teacher,Q是student,在大模型之前,似乎很多人更喜欢用FKL,正向KL散度(FKL)更受青睐的原因可能与其在传统任务上的表现有关。传统分类任务的输出空间相对较小,模式(即分布的峰值)较少,这意味着分布更倾向于单一峰值而非多峰值分布。在这种情况下,FKL表现良好,因为它倾向于让学生模型关注教师模型输出中概率较高的区域,从而产生更准确的样本。然而,对于大型语言模型(LLM)来说,输出空间更加复杂,模式更多,再使用FKL可能导致学生模型关注教师模型输出中概率较低的区域,从而产生不良样本。
如上图所示,教师模型是蓝色曲线,它的输出是可量化的,这里假设为两个高斯波峰,而黄色,是理想情况下,我们认为学生模型可以近似为正态分布来拟合教师曲线,那么会出现两种结果,一种是尽可能多的包括多峰的面积,第二种是直接拟合最高波峰的分布。所以左边是Forward KL,右边是反向。
中间的一些具体推导过程不过多赘述,近年有非常多的论文对该方案做了benchmark,比如说下图是《f-Divergence Minimization for Sequence-Level Knowledge Distillation》一文的数据:
还有《Rethinking Kullback-Leibler Divergence in Knowledge Distillation for Large Language Models》篇的数据和AKL:
另外说明一下,本节内容就是看了作者在知乎发的《LLM的知识蒸馏(KD)应该用Reverse KL?》一文才有想法撰写本节,对于想复现的小伙伴来讲,可以去看这几篇论文的github,作者还给了一些相应的可视化demo。
TRL(Transformer Reinforcement Learning)库是用于后续训练基础模型的综合库,专为使用监督微调 (SFT)、近端策略优化 (PPO) 和直接偏好优化 (DPO) 等先进技术进行训练后的基础模型而设计。这里我们只看它里面的两种trainer——SFTtrainer和GKDtrainer。
从原理方面来讲:
从损失计算方面来讲:
这两种顺序非常直观,GKDTrainer继承自SFTTrainer,SFTTrainer继承自Trainer。那从SFTtrainer看,它的调用非常简单,trl的readme直接写了一个demo:
from trl import SFTConfig, SFTTrainer
from datasets import load_dataset
dataset = load_dataset("trl-lib/Capybara", split="train")
training_args = SFTConfig(output_dir="Qwen/Qwen2.5-0.5B-SFT")
trainer = SFTTrainer(
args=training_args,
model="Qwen/Qwen2.5-0.5B",
train_dataset=dataset,
)
trainer.train()
调用该类后,我又去看了下transformers的trainer,它的损失函数为:
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
"""
How the loss is computed by Trainer. By default, all models return the loss in the first element.
Subclass and override for custom behavior.
"""
if (self.label_smoother isnotNoneor self.compute_loss_func isnotNone) and"labels"in inputs:
labels = inputs.pop("labels")
else:
labels = None
if self.model_accepts_loss_kwargs:
loss_kwargs = {}
if num_items_in_batch isnotNone:
loss_kwargs["num_items_in_batch"] = num_items_in_batch
inputs = {**inputs, **loss_kwargs}
outputs = model(**inputs)
# Save past state if it exists
# TODO: this needs to be fixed and made cleaner later.
if self.args.past_index >= 0:
self._past = outputs[self.args.past_index]
if labels isnotNone:
unwrapped_model = self.accelerator.unwrap_model(model)
if _is_peft_model(unwrapped_model):
model_name = unwrapped_model.base_model.model._get_name()
else:
model_name = unwrapped_model._get_name()
# User-defined compute_loss function
if self.compute_loss_func isnotNone:
loss = self.compute_loss_func(outputs, labels, num_items_in_batch=num_items_in_batch)
elif model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
loss = self.label_smoother(outputs, labels, shift_labels=True)
else:
loss = self.label_smoother(outputs, labels)
else:
if isinstance(outputs, dict) and"loss"notin outputs:
raise ValueError(
"The model did not return a loss from the inputs, only the following keys: "
f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
)
# We don't use .loss here since the model may return tuples instead of ModelOutput.
loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
loss *= self.accelerator.num_processes
return (loss, outputs) if return_outputs else loss
很显然这部分有非常多的自适应判断,根据我们上一层为SFTtrainer类,并且没有指定loss方法,所以将选用cross-entropy loss作为模型训练参数。
而GKDtrainer类的方式就不一样,由于KL散度是不对称的,在知识蒸馏中使用JSD,Jensen-Shannon Divergence 是基于KL散度改进的更平滑和对称的概率分布度量。论文中给出了其改进的计算公式:
那自然其重写了compute_loss,具体计算为generalized_jsd_loss,代码如下:
def generalized_jsd_loss(
student_logits, teacher_logits, labels=None, beta=0.5, temperature=1.0, reduction="batchmean"
):
"""
Compute the generalized Jensen-Shannon Divergence loss for knowledge distillation using F.kl_div. See Eq. (1)
of https://huggingface.co/papers/2306.13649 for the definition.
Args:
student_logits: Tensor of shape (batch_size, sequence_length, vocab_size)
teacher_logits: Tensor of shape (batch_size, sequence_length, vocab_size)
labels: Tensor of shape (batch_size, sequence_length) with -100 for padding tokens to ignore when computing loss
beta: Interpolation coefficient between 0 and 1 (default: 0.5)
temperature: Softmax temperature (default: 1.0)
reduction: Specifies the reduction to apply to the output (default: 'batchmean')
Returns:
loss: Scalar tensor with the generalized JSD loss
"""
# Apply temperature scaling
student_logits = student_logits / temperature
teacher_logits = teacher_logits / temperature
# Compute log probabilities for student and probabilities for teacher
student_log_probs = F.log_softmax(student_logits, dim=-1)
teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
# Compute the log of the mixture distribution
# log(a + b) = log(exp(log(a)) + exp(log(b))) -> for mixture
beta = torch.tensor(beta, dtype=student_log_probs.dtype)
mixture_log_probs = torch.logsumexp(
torch.stack([student_log_probs + torch.log(beta), teacher_log_probs + torch.log(1 - beta)]),
dim=0,
)
# Compute KL divergences using F.kl_div
# PyTorch differs from the standard mathematical definition, so the order of the probability distributions is swapped compared to that defined in the paper.
kl_teacher = F.kl_div(mixture_log_probs, teacher_log_probs, reduction="none", log_target=True)
kl_student = F.kl_div(mixture_log_probs, student_log_probs, reduction="none", log_target=True)
# Compute the Generalized Jensen-Shannon Divergence
jsd = beta * kl_teacher + (1 - beta) * kl_student
# Masking
if labels isnotNone:
mask = labels != -100
jsd = jsd[mask]
# Apply reduction
if reduction == "batchmean":
return jsd.sum() / mask.sum() if labels isnotNoneelse jsd.sum() / (jsd.size(0) * jsd.size(1))
elif reduction == "sum":
return jsd.sum()
elif reduction == "mean":
return jsd.mean()
else:
return jsd
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
# compute student output
outputs_student = model(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
)
# compute teacher output in eval mode
self.teacher_model.eval()
with torch.no_grad():
outputs_teacher = self.teacher_model(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
)
# slice the logits for the generated tokens using the inputs["prompts"] lengths
prompt_lengths = inputs["prompts"].shape[1]
shifted_student_logits = outputs_student.logits[:, prompt_lengths - 1 : -1, :]
shifted_teacher_logits = outputs_teacher.logits[:, prompt_lengths - 1 : -1, :]
shifted_labels = inputs["labels"][:, prompt_lengths:]
# compute loss
loss = self.generalized_jsd_loss(
student_logits=shifted_student_logits,
teacher_logits=shifted_teacher_logits,
labels=shifted_labels,
beta=self.beta,
)
# empty cache
empty_cache()
# Return loss
return (loss, outputs_student) if return_outputs else loss
对于该类好不好用,我也不知道,暂时没用过,只能说从理论来分析,JSD损失和KL损失的区别,不过与SFTtrainer类似,调用方式也很简单,可以跑几次看看情况:
from datasets import load_dataset
import random
from transformers import AutoTokenizer
from trl import (
GKDConfig,
GKDTrainer,
LogCompletionsCallback,
ModelConfig,
ScriptArguments,
TrlParser,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
)
################
# Training
################
trainer = GKDTrainer(
model=model_config.model_name_or_path,
teacher_model=training_args.teacher_model_name_or_path,
args=training_args,
train_dataset=dataset[args.dataset_train_split],
eval_dataset=test_data,
processing_class=tokenizer,
peft_config=get_peft_config(model_config),
)
completions_callback = LogCompletionsCallback(trainer, trainer.generation_config, num_prompts=8)
trainer.add_callback(completions_callback)
trainer.train()
# Save
trainer.save_model(training_args.output_dir)
本节是对阳哥夺冠方案中关于蒸馏部分的经典总结,在这里做一个旁征博引,因为没有算力,具体我也没复现过,不过算是除了写这篇推文的初衷,本来是想做一个top方案亮点汇总,只是因为deepseek的爆火针对其中一个方向做了延展。那话不多说,github原址为:https://github.com/shyoulala/LMSYS_BlackPearl
该仓库的目录结构为:
./model_path # 预训练模型的路径,存放预训练模型的权重和配置文件
./src_fast # 快速训练脚本的存放位置,可能包含简化的训练代码
./src # 完整解决方案的代码目录,包含整个项目的完整训练和处理流程
./data # 数据目录,存放训练数据和其他相关数据
./data/oof # Out-of-Fold 数据目录,可能用于交叉验证的中间结果
./data/processed_data # 处理后的数据目录,存放经过预处理的数据
./data/processed_data/orgemma2fold4 # 训练集,包含用于直接蒸馏的 70b 概率数据(第4折)
./data/processed_data/orgemma2fold2 # 同上,第2折
./data/processed_data/orgemma2fold0 # 同上,第0折
./data/processed_data/orgemma2fold1 # 同上,第1折
./data/processed_data/orgemma2fold3 # 同上,第3折
./data/lmsys-chatbot-arena # 可能存放与 LMSYS Chatbot Arena 相关的数据或资源
./sub # 输出目录,用于存放训练结果、预测结果等
./model_save # 训练模型的保存路径,存放训练完成后的模型文件
./model_save_or # 另一个模型保存路径,可能是用于存放原始模型或特定版本的模型
./model_save_or/v7_ut_gemma_v7_64r128_ddgemma2_16bit # 经过后处理(如蒸馏)的模型版本,可能是 Gemma2-9B 的 16bit 版本
挺难想象的,大模型时代竟然还能做交叉验证,不过lmsys是个三分类任务,依照之前逻辑也没什么问题,该方案主要是用llama3-70B和Qwen2-72B-instruct对gamma2-9B做蒸馏,所有大致流程,都通过run_pipeline.sh有显现:
#!/bin/bash
set -e
qwen_path=../model_path/qwen2_72b
llama_path=../model_path/llama3_70b
gemma_path=../model_path/Gemma2_9b
qwen_path_ut=../model_save/qwen2_4bit_pretrain/epoch_0_model/adapter.bin
llama_path_ut=../model_save/llama3_4bit_pretrain/epoch_0_model/adapter.bin
gemma_path_ut=../model_save/gemma2_4bit_pretrain/epoch_0_model/adapter.bin
fold=$1
echo run:${fold}
# train llama3 70b
sh run_fintune.sh llama3 ${llama_path} ${llama_path_ut} ${fold}
# predict train logits
python predict_train.py ${llama_path} ../model_save/llama3_4bit_load_fintune/epoch_0_model/adapter.bin ../data/processed_data/llama3fold${fold}/train.parquet ../data/oof/llama3fold${fold}_train.parquet
# train qwen2 70b
sh run_fintune.sh qwen2 ${qwen_path} ${qwen_path_ut} ${fold}
# predict train logits
python predict_train.py ${qwen_path} ../model_save/qwen2_4bit_load_fintune/epoch_0_model/adapter.bin ../data/processed_data/qwen2fold${fold}/train.parquet ../data/oof/qwen2fold${fold}_train.parquet
# merge logits
python merge_logits.py ../data/processed_data/gemma2fold${fold}/train.parquet ../data/oof/qwen2fold${fold}_train.parquet ../data/oof/llama3fold${fold}_train.parquet ../data/processed_data/gemma2fold${fold}/train_logits.parquet
# distill fintune gemma2-9b
sh run_fintune_16bit_distill.sh gemma2 ${gemma_path} ${gemma_path_ut} ${fold}
中间几步有挺多有趣的操作,比如是如何做post train的,以及最后merge logits,这里仅谈蒸馏之前的merge lora,因为代码足够简单:
import time
from dataclasses import dataclass
import pickle
import torch
import sklearn
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from transformers import Gemma2ForSequenceClassification, GemmaTokenizerFast, BitsAndBytesConfig
from transformers.data.data_collator import pad_without_fast_tokenizer_warning
from peft import get_peft_config, PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType
lora_dir = '../model_save/gemma2fold0_16bit_load_fintune/best_val_loss_model/adapter.bin'
d1 = torch.load(lora_dir)
lora_dir = '../model_save/gemma2fold1_16bit_load_fintune/best_val_loss_model/adapter.bin'
d2 = torch.load(lora_dir)
lora_dir = '../model_save/gemma2fold2_16bit_load_fintune/best_val_loss_model/adapter.bin'
d3 = torch.load(lora_dir)
lora_dir = '../model_save/gemma2fold3_16bit_load_fintune/best_val_loss_model/adapter.bin'
d4 = torch.load(lora_dir)
lora_dir = '../model_save/gemma2fold4_16bit_load_fintune/best_val_loss_model/adapter.bin'
d5 = torch.load(lora_dir)
d = {}
for k, v in d1.items():
v = d1[k] + d2[k] + d3[k] + d4[k] + d5[k]
v = v / 5.
d[k] = v
torch.save(d, "../model_save/final_adapter.bin")
代码上可见,就是对经过5次交叉验证的gamma模型权重做了加权平均合并,但我看discussion很多人提到了,它们同样想到了该方案,不过效果并不好,似乎是这些权重还需要做方差评估,如果方差过大反而会拖累加权后的结果,感兴趣有卡有算力的能进行尝试,我就不过多提了。
回到正题,最终是先得到了llama3和Qwen的模型输出,那么蒸馏即是需要考虑这两者的结果,所以蒸馏损失选择了:
loss_fun = nn.CrossEntropyLoss()
divergence_loss_fn = nn.KLDivLoss(reduction='batchmean')
cos_loss_fn = nn.CosineEmbeddingLoss()
outputs = model(batch['input_ids'], use_cache=False) # predict gemma2
logits = outputs.logits
grads = batch['grads']
grads1 = batch['grads'][:, :3] # qwen2
grads2 = batch['grads'][:, 3:] # llama3
labels = batch['labels']
loss_ce = loss_fun(logits, labels)
loss_grad1 = divergence_loss_fn(
F.log_softmax(logits / T, dim=1),
F.softmax(grads1 / T, dim=1)
)
cos_loss1 = cos_loss_fn(F.softmax(grads1 / T, dim=1), F.softmax(logits / T, dim=1),
torch.ones(logits.size()[0]).to(logits.device))
loss_grad2 = divergence_loss_fn(
F.log_softmax(logits / T, dim=1),
F.softmax(grads2 / T, dim=1)
)
cos_loss2 = cos_loss_fn(F.softmax(grads2 / T, dim=1), F.softmax(logits / T, dim=1),
torch.ones(logits.size()[0]).to(logits.device))
loss = (loss_ce + loss_grad1 + cos_loss1 + loss_grad2 + cos_loss2) / 5.
用数学公式理解,即为交叉熵和KL散度的混合:
这里刚开始我不是很理解,然后问了下deepseek懂了:
为什么同时使用交叉熵损失和 KL 散度损失?
1. 保持监督学习能力
交叉熵损失确保学生模型能够正确预测真实标签,从而保持模型的监督学习能力。如果没有交叉熵损失,学生模型可能会过度依赖教师模型的输出,而忽视真实标签的指导,导致模型在真实数据上的性能下降。
2. 学习教师模型的软目标
KL 散度损失让学生模型学习教师模型的软目标,从而捕捉到教师模型的内部表示和知识。软目标通常包含更多的信息,可以帮助学生模型更好地理解数据的分布和特征。
3. 平衡硬标签和软目标
同时使用交叉熵损失和 KL 散度损失可以平衡硬标签和软目标的贡献。硬标签(真实标签)提供了直接的监督信号,而软目标(教师模型的输出)提供了更多的上下文信息。通过调整两者的权重,可以更好地指导学生模型的学习。
其实我认为以上主要的,是因为教师模型是两个,而不是一个,KL更适合于一个,而两个加入交叉熵我的理解为桥接,更能体现泛化,但具体为啥这样安排,只有跑了才知道,所以根据github的环境说明,有8张A100以上的,可以跑一轮,等待3天以上,观看结果了。
该repo是DeepSeek-R1的开放复现版本,由huggingface的CEO亲自提出并进行,我大致看了一下,它的规划是:
这里重点看step 1,即它使用distilabel来对Deepseek-R1提取蒸馏数据,以下是一个简单demo:
from datasets import load_dataset
from distilabel.models import vLLM
from distilabel.pipeline import Pipeline
from distilabel.steps.tasks import TextGeneration
prompt_template = """\
You will be given a problem. Please reason step by step, and put your final answer within \boxed{}:
{{ instruction }}"""
dataset = load_dataset("AI-MO/NuminaMath-TIR", split="train").select(range(10))
model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"# Exchange with another smol distilled r1
with Pipeline(
name="distill-qwen-7b-r1",
description="A pipeline to generate data from a distilled r1 model",
) as pipeline:
llm = vLLM(
model=model_id,
tokenizer=model_id,
extra_kwargs={
"tensor_parallel_size": 1,
"max_model_len": 8192,
},
generation_kwargs={
"temperature": 0.6,
"max_new_tokens": 8192,
},
)
prompt_column = "problem"
text_generation = TextGeneration(
llm=llm,
template=prompt_template,
num_generations=4,
input_mappings={"instruction": prompt_column} if prompt_column isnotNoneelse {}
)
if __name__ == "__main__":
distiset = pipeline.run(dataset=dataset)
distiset.push_to_hub(repo_id="username/numina-deepseek-r1-qwen-7b")
然后将该数据加入了sft中:
def main(script_args, training_args, model_args):
################
# Model init kwargs & Tokenizer
################
quantization_config = get_quantization_config(model_args)
model_kwargs = dict(
revision=model_args.model_revision,
trust_remote_code=model_args.trust_remote_code,
attn_implementation=model_args.attn_implementation,
torch_dtype=model_args.torch_dtype,
use_cache=Falseif training_args.gradient_checkpointing elseTrue,
device_map=get_kbit_device_map() if quantization_config isnotNoneelseNone,
quantization_config=quantization_config,
)
training_args.model_init_kwargs = model_kwargs
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True
)
tokenizer.pad_token = tokenizer.eos_token
################
# Dataset
################
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
################
# Training
################
trainer = SFTTrainer(
model=model_args.model_name_or_path,
args=training_args,
train_dataset=dataset[script_args.dataset_train_split],
eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no"elseNone,
processing_class=tokenizer,
peft_config=get_peft_config(model_args),
)
trainer.train()
# Save and push to hub
trainer.save_model(training_args.output_dir)
if training_args.push_to_hub:
trainer.push_to_hub(dataset_name=script_args.dataset_name)
if __name__ == "__main__":
parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
script_args, training_args, model_args = parser.parse_args_and_config()
main(script_args, training_args, model_args)
从代码上可以看到,这个过程是从教师模型中提取知识,并将其传递给学生模型。在这个特定的情况下,知识不是以软标签的形式直接传递,而是通过生成的推理数据来传递。这种方法通常被称为数据蒸馏(Data Distillation)或示例蒸馏(Example Distillation),它是知识蒸馏的一种变体。
我看到了腾讯科技发布的一场关于DeepSeek的高质量闭门会:比技术更重要的是愿景 ,里面的很多内容可以作为结尾:
文章的最后,因为deepseek火爆的出圈,我也看到了很多各类博主在其上做各种任务,比如写文章或者写诗,其中有一首我很喜欢的,是它们都比deepseek 好,我知道 一文中使用deepseek生成的其一,作为本篇结尾:
春城惊岁晚,梅魂初醒,滇海骤翻银浪。西山素甲,南天冻幕,翠湖暗锁寒香。冰绡裹垂杨,讶螺峰披絮,金马凝霜。万户笙箫,尽收檐角作琳琅。
谁教玉戏蛮乡?遣滕六醉舞,姑射颠狂。谢女絮迷,袁安户掩,争知南诏风光。椒盘冷红妆,想罗裙冰透,画阁炉藏。且待明朝晴暖,花事又铺张。
53AI,企业落地大模型首选服务商
产品:场景落地咨询+大模型应用平台+行业解决方案
承诺:免费场景POC验证,效果验证后签署服务协议。零风险落地应用大模型,已交付160+中大型企业
2025-01-27
HybridFlow:基于 Ray 构建灵活且高效的 RLHF 编程框架
2025-01-27
ollama 部署 deepseek-r1 70B 模型完整指南
2025-01-26
AI根据接口文档生成服务端模拟工程
2025-01-26
聊聊DeepSeek R1的知识蒸馏与应用思考
2025-01-25
谈谈对DeepSeek-R1的一些理解
2025-01-24
手把手教你AnythingLLM+Ollama+qwen 部署本地大模型
2025-01-24
AI大模型那么强,它是吃什么长大的?
2025-01-24
谁说AI Agent杀死了RPA,UiPath第一个不服!来自头部RPA应用CEO Daniel Dines的深刻洞察
2024-09-18
2024-07-11
2024-07-11
2024-07-26
2024-07-09
2024-06-11
2024-12-29
2024-10-20
2024-07-20
2024-07-12