AI知识库

53AI知识库

学习大模型的前沿技术与行业应用场景


大模型知识蒸馏指南
发布日期:2025-01-28 13:51:03 浏览次数: 1666 来源:许同学说
推荐语

模型蒸馏技术深度解析,助力资源受限设备性能提升。

核心内容:
1. 模型蒸馏技术概述及其在深度学习中的应用价值
2. 知识蒸馏的核心组件:知识、蒸馏算法、师生架构
3. 知识蒸馏流程详解,包括soft targets的作用与影响

杨芳贤
53A创始人/腾讯云(TVP)最具价值专家

最近wsdm cup到了瓶颈,租卡跑算力成本太高,而lmsys比赛的微调结果也没啥可抄的了,所以只能回头看看top方案,研究了一下阳哥的《Distill is all you need》,和第二名tascj对于训练推理的科技与狠活,有些感觉,伴随着deepseek的大火,蒸馏和强化学习又被端上了台面,我对强化学习暂时没什么兴趣,不过蒸馏跟我最近看的内容相关,在网上搜了一圈关于deepseek针对蒸馏的策略,好像没有过多内容介绍,于是想着总结找到的一些资料。

什么是模型蒸馏?

模型蒸馏即知识蒸馏(Knowledge Distillation),是一种模型压缩和加速技术。在深度学习中,大型深度神经网络虽性能优异,但因计算复杂度高、存储需求大,难以部署在资源受限设备上。模型蒸馏通过构建师生架构,让小的学生模型学习大的教师模型的知识,使学生模型在保持较小规模的同时,尽可能接近教师模型的性能。其核心组件包括知识(如教师模型的 logits、中间层特征等)、蒸馏算法(用于指导知识转移)和师生架构(决定知识传递方式)。

这里可以看比较主流的一张图,出自2021年综述:《Knowledge Distillation: A Survey》,对近年的Distillation做了一个详细概括,Knowledge Distillation的流程可以理解为:

图中除了loss之后会详细说明,唯一的未知点可能在于soft targets,它是经过softmax的下一层级结果logits(原始分数),公式为:

其中是温度系数,从公式中能很明显看出当值较大时,Softmax 输出的概率分布会更加平滑,每个类别的概率值相对更接近;值较小时,概率分布会更尖锐,高概率类别的概率值远高于其他类别。这些 soft targets 会传递给学生模型,学生模型在学习过程中不仅学习真实的hard targets信息,还能从教师模型的 soft targets 中获取类别之间的关联等知识,帮助其更好地训练和泛化。

hard targets 与 soft targets的区别可以从下面的四分类图中很形象的看出:

知识蒸馏有什么意义

  • 实现模型压缩与加速:模型蒸馏能有效压缩模型大小、降低计算复杂度,提升推理速度。如在论文研究中,通过知识蒸馏将大模型知识转移到小模型,在 CIFAR10 和 CIFAR100 数据集上进行实验,结果表明可实现不同深度模型的压缩,使轻量级学生模型在保持较高准确率的同时,显著减少模型参数和计算量,满足在资源受限设备上的部署需求 。
  • 提升模型性能:帮助学生模型学习到教师模型的有用知识,提高自身性能。在视觉识别、自然语言处理、语音识别等多个领域的研究中发现,知识蒸馏可提升模型在复杂任务中的表现。例如在自然语言处理中,对BERT 模型进行知识蒸馏得到的轻量级模型,在保持较高准确率的同时,推理速度大幅提升,能够高效完成多种语言任务 。
  • 解决数据相关问题:在数据稀缺、存在隐私问题或数据难以获取时,模型蒸馏有独特优势。数据无关蒸馏方法可利用教师模型生成合成数据训练学生模型,避免对大量真实数据的依赖。在涉及敏感数据的场景中,多教师蒸馏可让多个教师模型分别处理不同子集数据,监督学生模型训练,既能保护数据隐私,又能完成模型训练。
  • 促进跨领域与跨模态学习:跨模态蒸馏可实现不同模态间的知识转移,帮助模型更好地处理多模态数据。在一些研究中,将 RGB 图像模态的知识转移到深度图像模态,使模型在不同模态下都能取得较好的性能,拓宽了模型的应用范围。
  • 助力终身学习与持续优化:与终身学习结合,模型蒸馏可帮助模型在新任务学习中保留旧知识,避免灾难性遗忘。在不断出现新数据和新任务的场景下,通过知识蒸馏将已有知识传递给新模型,使模型能够持续学习和优化,提升其适应性和泛化能力。

如何做知识蒸馏

做知识蒸馏的方式有非常多,从训练方案流程来看,就有离线蒸馏、在线蒸馏和自蒸馏等,从算法更新角度上,还有对抗蒸馏、多教师蒸馏等,这里我就不用豆包在灌水了,想查一大片说明,直接以bert时代的蒸馏开始看。

unsetunsettinybertunsetunset

TinyBERT是一种轻量级的预训练语言模型,由华为和华中科技大学提出。它通过知识蒸馏技术,将BERT模型的知识迁移到一个更小的模型中,从而实现了模型体积的大幅减小和推理速度的提升。在当时,它提出了 两阶段transformer蒸馏方案:在大规模语料上首先进行通用MLM任务的蒸馏,在下游任务时,先学好老师模型,再进行蒸馏,具体如下图:

关于Transformer层蒸馏,主要包括注意力attn的蒸馏和隐藏层hidn的蒸馏:

关于损失函数,TinyBert的蒸馏loss为:

  1. 第一项:词向量层损失

  • 计算学生词向量和老师词向量的均方误差:
  • 因为的维度末必一致,这里需要参数做映射
  • 第二项:中间层损失

    • 学生第 i 层多头注意力矩阵和老师第 j 层多头注意力矩阵计算MSE, K 为注意力的head数
    • 学生的第 i 层隐层输出和 老师的第 j 层隐层输出计算MSE,用做映射
    • 若学生4层,老师12层,则老师的  (3,6,9,12)  层分别蒸馏到学生的  (1,2,3,4)  层。
    • 中间层的损失由隐层均方误差损失和注意力损失组成:
    • 隐层均方误差损失:
    • 注意力损失:
  • 第三项:预测层损失

    • 学生学习老师的soft label
    • 并计算交叉熵:

    如果有不清晰的,可以去看论文原文,我就不做过多解释了,上述的内容根据论文开源的github地址,其中对于蒸馏训练的截取部分,可进行一一对照:

    # 蒸馏配置
    distill_config = DistillationConfig(
        # 设置温度系数temperature, tiny-bert论文作者使用1表现最好,一般大于1比较好
        temperature=self.temperature,
        # 设置ground truth loss权重
        hard_label_weight=self.hard_label_weight,
        # 设置预测层蒸馏loss(即soft label损失)为交叉熵,并稍微放大其权重
        kd_loss_type=self.kd_loss_type, kd_loss_weight=self.kd_loss_weight,
        # 配置中间层蒸馏映射
        intermediate_matches=[
            # 配置hidden蒸馏映射、维度映射
            {'layer_T'0'layer_S'0'feature''hidden''loss''hidden_mse''weight'1,
             'proj': ['linear'312768]},  # embedding层输出
            {'layer_T'3'layer_S'1'feature''hidden''loss''hidden_mse''weight'1,
             'proj': ['linear'312768]},
            {'layer_T'6'layer_S'2'feature''hidden''loss''hidden_mse''weight'1,
             'proj': ['linear'312768]},
            {'layer_T'9'layer_S'3'feature''hidden''loss''hidden_mse''weight'1,
             'proj': ['linear'312768]},
            {'layer_T'12'layer_S'4'feature''hidden''loss''hidden_mse''weight'1,
             'proj': ['linear'312768]},
            # 配置attention矩阵蒸馏映射,注意layer序号从0开始
            {"layer_T"2"layer_S"0"feature""attention""loss""attention_mse""weight"1},
            {"layer_T"5"layer_S"1"feature""attention""loss""attention_mse""weight"1},
            {"layer_T"8"layer_S"2"feature""attention""loss""attention_mse""weight"1},
            {"layer_T"11"layer_S"3"feature""attention""loss""attention_mse""weight"1},
        ]
    )

    # 训练配置
    optimizer = AdamW(self.student_model.parameters(), lr=self.lr)  # 使用大一点的lr
    train_config = TrainingConfig(
        output_dir=self.student_model_dir, device=self.student_trainer.device,
        data_parallel=self.enable_parallel, ckpt_frequency=self.ckpt_frequency  # 一个epoch存ckpt_frequency次模型
    )

    # 配置model中logits hiddens attentions losses的获取方法
    def simple_adaptor(batch, model_outputs):
        return {
            'logits': model_outputs[-1]['logits'], 'hidden': model_outputs[-1]['hiddens'],
            'attention': model_outputs[-1]['attentions'], 'losses': model_outputs[1],
        }

    # 蒸馏
    distiller = GeneralDistiller(
        train_config=train_config, distill_config=distill_config,
        model_T=self.teacher_model, model_S=self.student_model,
        adaptor_T=simple_adaptor, adaptor_S=simple_adaptor
    )
    with distiller:
        logger.info('start to knowledge distill ...')
        distiller.train(optimizer, train_dataloader, num_epochs=epoch)
        logger.info('distill finish')

    unsetunsetKL散度(**Kullback-Leibler divergence**)unsetunset

    KL散度的定义是建立在熵(Entropy)的基础上的。此处以离散随机变量为例,若一个离散随机变量的可能取值为而对应的概率为,则随机变量的熵定义为:

    若有两个随机变量,且其概率分布分别为,则相对的相对摘为:

    之所以称之为相对熵,是因为其可以通过两随机变量的交叉嫡(Cross-Entropy)以及信息摘推导得到,针对上述离散变量的概率分布而言,其交叉摘定义为:

    因此,KL散度或相对熵可通过下式得出:

    在上一节中,TinyBERT在设计其蒸馏过程时采用了多种损失函数,包括词向量层损失、中间层损失和预测层损失,在大模型时代下,词向量损失不用多说,因为已经完全做了解耦,如何进行embedding我想看到这里的都知道,中间层损失的不再使用,或者说中间层蒸馏的使用变少,我理解是大模型通常已经具有足够的参数来学习复杂的特征表示,因此它的必要性相对较低,另外就是中间层叠得太厚,所能获得的收益太低,所以不如针对预测层进行相应的改进,那自然,就不得不提本节在介绍的KL散度。

    那为什么作为大模型来讲,更多使用KL散度呢?我觉得可以从以下三点考虑:

    1. 知识蒸馏的需求:大模型在进行知识蒸馏时,需要将教师模型的知识传递给学生模型。KL散度能够衡量两个概率分布之间的差异,适合用于衡量教师模型和学生模型之间的输出分布差异。通过最小化KL散度,可以使得学生模型的输出分布尽可能接近教师模型的输出分布。
    2. 考虑分布的整体差异:KL散度不仅考虑了预测分布与真实分布之间的交叉熵,还考虑了真实分布的熵。这使得KL散度能够更全面地衡量两个分布之间的差异,适合用于大模型这种需要精细调整输出分布的场景
    3. 优化目标的一致性:在知识蒸馏中,优化KL散度等价于优化交叉熵。但是,KL散度在某些情况下能够提供更稳定的优化目标,尤其是在教师模型和学生模型的输出分布差异较大时。

    unsetunset前向KL(forward)和后向KL(reverseunsetunset

    上述介绍了KL散度的定义,很明显,KL损失不是一个对称形式,即,那么我们可以试图用近似分布来优化该目标:

    1. Minimizing the forward KL: 
    2. Minimizing the reverse KL: 

    根据上一小节的概率公式推导,可以计算出反向 (Reverse KL,RKL)为:

    正向 (Forward KL,FKL)为:

    其中P是teacher,Q是student,在大模型之前,似乎很多人更喜欢用FKL,正向KL散度(FKL)更受青睐的原因可能与其在传统任务上的表现有关。传统分类任务的输出空间相对较小,模式(即分布的峰值)较少,这意味着分布更倾向于单一峰值而非多峰值分布。在这种情况下,FKL表现良好,因为它倾向于让学生模型关注教师模型输出中概率较高的区域,从而产生更准确的样本。然而,对于大型语言模型(LLM)来说,输出空间更加复杂,模式更多,再使用FKL可能导致学生模型关注教师模型输出中概率较低的区域,从而产生不良样本。

    如上图所示,教师模型是蓝色曲线,它的输出是可量化的,这里假设为两个高斯波峰,而黄色,是理想情况下,我们认为学生模型可以近似为正态分布来拟合教师曲线,那么会出现两种结果,一种是尽可能多的包括多峰的面积,第二种是直接拟合最高波峰的分布。所以左边是Forward KL,右边是反向。

    中间的一些具体推导过程不过多赘述,近年有非常多的论文对该方案做了benchmark,比如说下图是《f-Divergence Minimization for Sequence-Level Knowledge Distillation》一文的数据:

    还有《Rethinking Kullback-Leibler Divergence in Knowledge Distillation for Large Language Models》篇的数据和AKL:

    另外说明一下,本节内容就是看了作者在知乎发的《LLM的知识蒸馏(KD)应该用Reverse KL?》一文才有想法撰写本节,对于想复现的小伙伴来讲,可以去看这几篇论文的github,作者还给了一些相应的可视化demo。

    unsetunsettrl中的知识蒸馏unsetunset

    TRL(Transformer Reinforcement Learning)库是用于后续训练基础模型的综合库,专为使用监督微调 (SFT)、近端策略优化 (PPO) 和直接偏好优化 (DPO) 等先进技术进行训练后的基础模型而设计。这里我们只看它里面的两种trainer——SFTtrainer和GKDtrainer。

    从原理方面来讲:

    • SFTTrainer:SFTTrainer 即监督微调训练器,主要是对预训练语言模型进行有监督的微调。它利用给定的输入输出对数据,通过最小化模型输出与真实标签之间的损失,让模型学习到特定任务的模式,将预训练模型适配到具体的下游任务。
    • GKDTrainer:GKDTrainer 是用于知识蒸馏的一种训练器,基于知识蒸馏原理,利用教师模型的知识来指导学生模型的训练,使学生模型学习到教师模型的知识,比如输出分布、特征表示等,以提高学生模型的性能。

    从损失计算方面来讲:

    • SFTTrainer:通常计算模型输出与真实标签之间的交叉熵损失等,衡量模型预测结果与实际标注的差异,通过反向传播来更新模型参数,使模型输出尽可能接近真实标签。
    • GKDTrainer:主要计算学生模型与教师模型输出之间的散度,如 Jensen - Shannon Divergence(JSD)、Kullback - Leibler Divergence(KLD)等,让学生模型学习教师模型的输出分布等知识。

    这两种顺序非常直观,GKDTrainer继承自SFTTrainer,SFTTrainer继承自Trainer。那从SFTtrainer看,它的调用非常简单,trl的readme直接写了一个demo:

    from trl import SFTConfig, SFTTrainer
    from datasets import load_dataset

    dataset = load_dataset("trl-lib/Capybara", split="train")

    training_args = SFTConfig(output_dir="Qwen/Qwen2.5-0.5B-SFT")
    trainer = SFTTrainer(
        args=training_args,
        model="Qwen/Qwen2.5-0.5B",
        train_dataset=dataset,
    )
    trainer.train()

    调用该类后,我又去看了下transformers的trainer,它的损失函数为:

        def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
            """
            How the loss is computed by Trainer. By default, all models return the loss in the first element.

            Subclass and override for custom behavior.
            """

            if (self.label_smoother isnotNoneor self.compute_loss_func isnotNoneand"labels"in inputs:
                labels = inputs.pop("labels")
            else:
                labels = None
            if self.model_accepts_loss_kwargs:
                loss_kwargs = {}
                if num_items_in_batch isnotNone:
                    loss_kwargs["num_items_in_batch"] = num_items_in_batch
                inputs = {**inputs, **loss_kwargs}
            outputs = model(**inputs)
            # Save past state if it exists
            TODO: this needs to be fixed and made cleaner later.
            if self.args.past_index >= 0:
                self._past = outputs[self.args.past_index]

            if labels isnotNone:
                unwrapped_model = self.accelerator.unwrap_model(model)
                if _is_peft_model(unwrapped_model):
                    model_name = unwrapped_model.base_model.model._get_name()
                else:
                    model_name = unwrapped_model._get_name()
                # User-defined compute_loss function
                if self.compute_loss_func isnotNone:
                    loss = self.compute_loss_func(outputs, labels, num_items_in_batch=num_items_in_batch)
                elif model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
                    loss = self.label_smoother(outputs, labels, shift_labels=True)
                else:
                    loss = self.label_smoother(outputs, labels)
            else:
                if isinstance(outputs, dict) and"loss"notin outputs:
                    raise ValueError(
                        "The model did not return a loss from the inputs, only the following keys: "
                        f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
                    )
                # We don't use .loss here since the model may return tuples instead of ModelOutput.
                loss = outputs["loss"if isinstance(outputs, dict) else outputs[0]

            if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
                loss *= self.accelerator.num_processes

            return (loss, outputs) if return_outputs else loss

    很显然这部分有非常多的自适应判断,根据我们上一层为SFTtrainer类,并且没有指定loss方法,所以将选用cross-entropy loss作为模型训练参数。

    而GKDtrainer类的方式就不一样,由于KL散度是不对称的,在知识蒸馏中使用JSD,Jensen-Shannon Divergence 是基于KL散度改进的更平滑和对称的概率分布度量。论文中给出了其改进的计算公式:

    那自然其重写了compute_loss,具体计算为generalized_jsd_loss,代码如下:

        def generalized_jsd_loss(
            student_logits, teacher_logits, labels=None, beta=0.5, temperature=1.0, reduction="batchmean"
        )
    :

            """
            Compute the generalized Jensen-Shannon Divergence loss for knowledge distillation using F.kl_div. See Eq. (1)
            of https://huggingface.co/papers/2306.13649 for the definition.

            Args:
                student_logits: Tensor of shape (batch_size, sequence_length, vocab_size)
                teacher_logits: Tensor of shape (batch_size, sequence_length, vocab_size)
                labels: Tensor of shape (batch_size, sequence_length) with -100 for padding tokens to ignore when computing loss
                beta: Interpolation coefficient between 0 and 1 (default: 0.5)
                temperature: Softmax temperature (default: 1.0)
                reduction: Specifies the reduction to apply to the output (default: 'batchmean')

            Returns:
                loss: Scalar tensor with the generalized JSD loss
            """


            # Apply temperature scaling
            student_logits = student_logits / temperature
            teacher_logits = teacher_logits / temperature

            # Compute log probabilities for student and probabilities for teacher
            student_log_probs = F.log_softmax(student_logits, dim=-1)
            teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)

            # Compute the log of the mixture distribution
            # log(a + b) = log(exp(log(a)) + exp(log(b))) -> for mixture
            beta = torch.tensor(beta, dtype=student_log_probs.dtype)
            mixture_log_probs = torch.logsumexp(
                torch.stack([student_log_probs + torch.log(beta), teacher_log_probs + torch.log(1 - beta)]),
                dim=0,
            )

            # Compute KL divergences using F.kl_div
            # PyTorch differs from the standard mathematical definition, so the order of the probability distributions is swapped compared to that defined in the paper.
            kl_teacher = F.kl_div(mixture_log_probs, teacher_log_probs, reduction="none", log_target=True)
            kl_student = F.kl_div(mixture_log_probs, student_log_probs, reduction="none", log_target=True)

            # Compute the Generalized Jensen-Shannon Divergence
            jsd = beta * kl_teacher + (1 - beta) * kl_student

            # Masking
            if labels isnotNone:
                mask = labels != -100
                jsd = jsd[mask]

            # Apply reduction
            if reduction == "batchmean":
                return jsd.sum() / mask.sum() if labels isnotNoneelse jsd.sum() / (jsd.size(0) * jsd.size(1))
            elif reduction == "sum":
                return jsd.sum()
            elif reduction == "mean":
                return jsd.mean()
            else:
                return jsd

        def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
            # compute student output
            outputs_student = model(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
            )

            # compute teacher output in eval mode
            self.teacher_model.eval()
            with torch.no_grad():
                outputs_teacher = self.teacher_model(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                )

            # slice the logits for the generated tokens using the inputs["prompts"] lengths
            prompt_lengths = inputs["prompts"].shape[1]
            shifted_student_logits = outputs_student.logits[:, prompt_lengths - 1 : -1, :]
            shifted_teacher_logits = outputs_teacher.logits[:, prompt_lengths - 1 : -1, :]
            shifted_labels = inputs["labels"][:, prompt_lengths:]

            # compute loss
            loss = self.generalized_jsd_loss(
                student_logits=shifted_student_logits,
                teacher_logits=shifted_teacher_logits,
                labels=shifted_labels,
                beta=self.beta,
            )

            # empty cache
            empty_cache()

            # Return loss
            return (loss, outputs_student) if return_outputs else loss

    对于该类好不好用,我也不知道,暂时没用过,只能说从理论来分析,JSD损失和KL损失的区别,不过与SFTtrainer类似,调用方式也很简单,可以跑几次看看情况:

    from datasets import load_dataset
    import random
    from transformers import AutoTokenizer
    from trl import (
        GKDConfig,
        GKDTrainer,
        LogCompletionsCallback,
        ModelConfig,
        ScriptArguments,
        TrlParser,
        get_kbit_device_map,
        get_peft_config,
        get_quantization_config,
    )

    ################
    # Training
    ################
    trainer = GKDTrainer(
        model=model_config.model_name_or_path,
        teacher_model=training_args.teacher_model_name_or_path,
        args=training_args,
        train_dataset=dataset[args.dataset_train_split],
        eval_dataset=test_data,
        processing_class=tokenizer,
        peft_config=get_peft_config(model_config),
    )
    completions_callback = LogCompletionsCallback(trainer, trainer.generation_config, num_prompts=8)
    trainer.add_callback(completions_callback)
    trainer.train()

    # Save
    trainer.save_model(training_args.output_dir)

    lmsys方案思考

    本节是对阳哥夺冠方案中关于蒸馏部分的经典总结,在这里做一个旁征博引,因为没有算力,具体我也没复现过,不过算是除了写这篇推文的初衷,本来是想做一个top方案亮点汇总,只是因为deepseek的爆火针对其中一个方向做了延展。那话不多说,github原址为:https://github.com/shyoulala/LMSYS_BlackPearl

    该仓库的目录结构为:

    ./model_path  # 预训练模型的路径,存放预训练模型的权重和配置文件
    ./src_fast    # 快速训练脚本的存放位置,可能包含简化的训练代码
    ./src         # 完整解决方案的代码目录,包含整个项目的完整训练和处理流程
    ./data        # 数据目录,存放训练数据和其他相关数据
    ./data/oof    # Out-of-Fold 数据目录,可能用于交叉验证的中间结果
    ./data/processed_data  # 处理后的数据目录,存放经过预处理的数据
    ./data/processed_data/orgemma2fold4  # 训练集,包含用于直接蒸馏的 70b 概率数据(第4折)
    ./data/processed_data/orgemma2fold2  # 同上,第2折
    ./data/processed_data/orgemma2fold0  # 同上,第0折
    ./data/processed_data/orgemma2fold1  # 同上,第1折
    ./data/processed_data/orgemma2fold3  # 同上,第3折
    ./data/lmsys-chatbot-arena  # 可能存放与 LMSYS Chatbot Arena 相关的数据或资源
    ./sub         # 输出目录,用于存放训练结果、预测结果等
    ./model_save  # 训练模型的保存路径,存放训练完成后的模型文件
    ./model_save_or  # 另一个模型保存路径,可能是用于存放原始模型或特定版本的模型
    ./model_save_or/v7_ut_gemma_v7_64r128_ddgemma2_16bit  # 经过后处理(如蒸馏)的模型版本,可能是 Gemma2-9B 的 16bit 版本

    挺难想象的,大模型时代竟然还能做交叉验证,不过lmsys是个三分类任务,依照之前逻辑也没什么问题,该方案主要是用llama3-70B和Qwen2-72B-instruct对gamma2-9B做蒸馏,所有大致流程,都通过run_pipeline.sh有显现:

    #!/bin/bash
    set -e

    qwen_path=../model_path/qwen2_72b
    llama_path=../model_path/llama3_70b
    gemma_path=../model_path/Gemma2_9b

    qwen_path_ut=../model_save/qwen2_4bit_pretrain/epoch_0_model/adapter.bin
    llama_path_ut=../model_save/llama3_4bit_pretrain/epoch_0_model/adapter.bin
    gemma_path_ut=../model_save/gemma2_4bit_pretrain/epoch_0_model/adapter.bin


    fold=$1
    echo run:${fold}
    # train llama3 70b
    sh run_fintune.sh llama3 ${llama_path}  ${llama_path_ut} ${fold}
    # predict train logits
    python predict_train.py ${llama_path} ../model_save/llama3_4bit_load_fintune/epoch_0_model/adapter.bin ../data/processed_data/llama3fold${fold}/train.parquet ../data/oof/llama3fold${fold}_train.parquet

    # train qwen2 70b
    sh run_fintune.sh qwen2 ${qwen_path}  ${qwen_path_ut} ${fold}
    # predict train logits
    python predict_train.py ${qwen_path} ../model_save/qwen2_4bit_load_fintune/epoch_0_model/adapter.bin ../data/processed_data/qwen2fold${fold}/train.parquet ../data/oof/qwen2fold${fold}_train.parquet

    # merge  logits 
    python merge_logits.py ../data/processed_data/gemma2fold${fold}/train.parquet ../data/oof/qwen2fold${fold}_train.parquet ../data/oof/llama3fold${fold}_train.parquet ../data/processed_data/gemma2fold${fold}/train_logits.parquet

    # distill fintune gemma2-9b
    sh run_fintune_16bit_distill.sh gemma2 ${gemma_path} ${gemma_path_ut} ${fold}

    中间几步有挺多有趣的操作,比如是如何做post train的,以及最后merge logits,这里仅谈蒸馏之前的merge lora,因为代码足够简单:

    import time
    from dataclasses import dataclass
    import pickle
    import torch
    import sklearn
    import numpy as np
    import pandas as pd
    from tqdm.auto import tqdm
    from transformers import Gemma2ForSequenceClassification, GemmaTokenizerFast, BitsAndBytesConfig
    from transformers.data.data_collator import pad_without_fast_tokenizer_warning
    from peft import get_peft_config, PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType

    lora_dir = '../model_save/gemma2fold0_16bit_load_fintune/best_val_loss_model/adapter.bin'
    d1 = torch.load(lora_dir)
    lora_dir = '../model_save/gemma2fold1_16bit_load_fintune/best_val_loss_model/adapter.bin'
    d2 = torch.load(lora_dir)
    lora_dir = '../model_save/gemma2fold2_16bit_load_fintune/best_val_loss_model/adapter.bin'
    d3 = torch.load(lora_dir)
    lora_dir = '../model_save/gemma2fold3_16bit_load_fintune/best_val_loss_model/adapter.bin'
    d4 = torch.load(lora_dir)
    lora_dir = '../model_save/gemma2fold4_16bit_load_fintune/best_val_loss_model/adapter.bin'
    d5 = torch.load(lora_dir)

    d = {}
    for k, v in d1.items():
        v = d1[k] + d2[k] + d3[k] + d4[k] + d5[k]
        v = v / 5.
        d[k] = v
    torch.save(d, "../model_save/final_adapter.bin")

    代码上可见,就是对经过5次交叉验证的gamma模型权重做了加权平均合并,但我看discussion很多人提到了,它们同样想到了该方案,不过效果并不好,似乎是这些权重还需要做方差评估,如果方差过大反而会拖累加权后的结果,感兴趣有卡有算力的能进行尝试,我就不过多提了。

    回到正题,最终是先得到了llama3和Qwen的模型输出,那么蒸馏即是需要考虑这两者的结果,所以蒸馏损失选择了:

    loss_fun = nn.CrossEntropyLoss()
    divergence_loss_fn = nn.KLDivLoss(reduction='batchmean')
    cos_loss_fn = nn.CosineEmbeddingLoss()
    outputs = model(batch['input_ids'], use_cache=False# predict gemma2
    logits = outputs.logits
    grads = batch['grads']
    grads1 = batch['grads'][:, :3# qwen2 
    grads2 = batch['grads'][:, 3:] # llama3
    labels = batch['labels']
    loss_ce = loss_fun(logits, labels)
    loss_grad1 = divergence_loss_fn(
        F.log_softmax(logits / T, dim=1),
        F.softmax(grads1 / T, dim=1)
    )
    cos_loss1 = cos_loss_fn(F.softmax(grads1 / T, dim=1), F.softmax(logits / T, dim=1),
                            torch.ones(logits.size()[0]).to(logits.device))

    loss_grad2 = divergence_loss_fn(
        F.log_softmax(logits / T, dim=1),
        F.softmax(grads2 / T, dim=1)
    )
    cos_loss2 = cos_loss_fn(F.softmax(grads2 / T, dim=1), F.softmax(logits / T, dim=1),
                            torch.ones(logits.size()[0]).to(logits.device))

    loss = (loss_ce + loss_grad1 + cos_loss1 + loss_grad2 + cos_loss2) / 5.

    用数学公式理解,即为交叉熵和KL散度的混合:

    这里刚开始我不是很理解,然后问了下deepseek懂了:

    为什么同时使用交叉熵损失和 KL 散度损失?

    1. 保持监督学习能力

    交叉熵损失确保学生模型能够正确预测真实标签,从而保持模型的监督学习能力。如果没有交叉熵损失,学生模型可能会过度依赖教师模型的输出,而忽视真实标签的指导,导致模型在真实数据上的性能下降。

    2. 学习教师模型的软目标

    KL 散度损失让学生模型学习教师模型的软目标,从而捕捉到教师模型的内部表示和知识。软目标通常包含更多的信息,可以帮助学生模型更好地理解数据的分布和特征。

    3. 平衡硬标签和软目标

    同时使用交叉熵损失和 KL 散度损失可以平衡硬标签和软目标的贡献。硬标签(真实标签)提供了直接的监督信号,而软目标(教师模型的输出)提供了更多的上下文信息。通过调整两者的权重,可以更好地指导学生模型的学习。

    其实我认为以上主要的,是因为教师模型是两个,而不是一个,KL更适合于一个,而两个加入交叉熵我的理解为桥接,更能体现泛化,但具体为啥这样安排,只有跑了才知道,所以根据github的环境说明,有8张A100以上的,可以跑一轮,等待3天以上,观看结果了。

    open-r1中的蒸馏

    该repo是DeepSeek-R1的开放复现版本,由huggingface的CEO亲自提出并进行,我大致看了一下,它的规划是:

    • 步骤 1:从 DeepSeek-R1 中提取高质量语料库来复制 R1-Distill 模型。
    • 步骤 2:复制 DeepSeek 用于创建 R1-Zero 的纯 RL 管道。这可能涉及为数学、推理和代码整理新的大规模数据集。
    • 步骤 3:展示我们可以通过多阶段训练从基础模型转向 RL 调整。

    这里重点看step 1,即它使用distilabel来对Deepseek-R1提取蒸馏数据,以下是一个简单demo:

    from datasets import load_dataset
    from distilabel.models import vLLM
    from distilabel.pipeline import Pipeline
    from distilabel.steps.tasks import TextGeneration


    prompt_template = """\
    You will be given a problem. Please reason step by step, and put your final answer within \boxed{}:
    {{ instruction }}"""


    dataset = load_dataset("AI-MO/NuminaMath-TIR", split="train").select(range(10))

    model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"# Exchange with another smol distilled r1

    with Pipeline(
        name="distill-qwen-7b-r1",
        description="A pipeline to generate data from a distilled r1 model",
    as pipeline:

        llm = vLLM(
            model=model_id,
            tokenizer=model_id,
            extra_kwargs={
                "tensor_parallel_size"1,
                "max_model_len"8192,
            },
            generation_kwargs={
                "temperature"0.6,
                "max_new_tokens"8192,
            },
        )
        prompt_column = "problem"
        text_generation = TextGeneration(
            llm=llm, 
            template=prompt_template,
            num_generations=4,
            input_mappings={"instruction": prompt_column} if prompt_column isnotNoneelse {}
        )


    if __name__ == "__main__":
        distiset = pipeline.run(dataset=dataset)
        distiset.push_to_hub(repo_id="username/numina-deepseek-r1-qwen-7b")

    然后将该数据加入了sft中:

    def main(script_args, training_args, model_args):
        ################
        # Model init kwargs & Tokenizer
        ################
        quantization_config = get_quantization_config(model_args)
        model_kwargs = dict(
            revision=model_args.model_revision,
            trust_remote_code=model_args.trust_remote_code,
            attn_implementation=model_args.attn_implementation,
            torch_dtype=model_args.torch_dtype,
            use_cache=Falseif training_args.gradient_checkpointing elseTrue,
            device_map=get_kbit_device_map() if quantization_config isnotNoneelseNone,
            quantization_config=quantization_config,
        )
        training_args.model_init_kwargs = model_kwargs
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True
        )
        tokenizer.pad_token = tokenizer.eos_token

        ################
        # Dataset
        ################
        dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

        ################
        # Training
        ################
        trainer = SFTTrainer(
            model=model_args.model_name_or_path,
            args=training_args,
            train_dataset=dataset[script_args.dataset_train_split],
            eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no"elseNone,
            processing_class=tokenizer,
            peft_config=get_peft_config(model_args),
        )

        trainer.train()

        # Save and push to hub
        trainer.save_model(training_args.output_dir)
        if training_args.push_to_hub:
            trainer.push_to_hub(dataset_name=script_args.dataset_name)


    if __name__ == "__main__":
        parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
        script_args, training_args, model_args = parser.parse_args_and_config()
        main(script_args, training_args, model_args)

    从代码上可以看到,这个过程是从教师模型中提取知识,并将其传递给学生模型。在这个特定的情况下,知识不是以软标签的形式直接传递,而是通过生成的推理数据来传递。这种方法通常被称为数据蒸馏(Data Distillation)或示例蒸馏(Example Distillation),它是知识蒸馏的一种变体。

    最后

    我看到了腾讯科技发布的一场关于DeepSeek的高质量闭门会:比技术更重要的是愿景 ,里面的很多内容可以作为结尾:

    1. 长期来说,通过走捷径的方式,而没有自己通过愿景去想怎么做技术方案,而是直接复现,中间可能会有不知道的坑。比如在这一代技术 long context 没有质变的前提下,解决问题的上限可能会被限制。R1-zero 可能是一个正确的方向,从头就做 R1-zero 或不通过类 o1 的数据启动可能更好。照着别人的技术方案可能不太好,希望更多探索。
    2. 蒸馏的坏处是模型 diversity 下降,影响模型上限,无法超越最强的模型。但短期看,蒸馏也是一条路线。其他模型用蒸馏也能得到较好的结果,未来在模型生态里面可能就会有老师、学生的角色区分,有能力当一名好学生也是一种可以的商业模式。

    文章的最后,因为deepseek火爆的出圈,我也看到了很多各类博主在其上做各种任务,比如写文章或者写诗,其中有一首我很喜欢的,是它们都比deepseek 好,我知道 一文中使用deepseek生成的其一,作为本篇结尾:

    春城惊岁晚,梅魂初醒,滇海骤翻银浪。西山素甲,南天冻幕,翠湖暗锁寒香。冰绡裹垂杨,讶螺峰披絮,金马凝霜。万户笙箫,尽收檐角作琳琅。

    谁教玉戏蛮乡?遣滕六醉舞,姑射颠狂。谢女絮迷,袁安户掩,争知南诏风光。椒盘冷红妆,想罗裙冰透,画阁炉藏。且待明朝晴暖,花事又铺张。


    53AI,企业落地大模型首选服务商

    产品:场景落地咨询+大模型应用平台+行业解决方案

    承诺:免费场景POC验证,效果验证后签署服务协议。零风险落地应用大模型,已交付160+中大型企业

    联系我们

    售前咨询
    186 6662 7370
    预约演示
    185 8882 0121

    微信扫码

    与创始人交个朋友

    回到顶部

     
    扫码咨询