AI知识库

53AI知识库

学习大模型的前沿技术与行业应用场景


剪枝与蒸馏的最佳实践
发布日期:2024-09-25 14:42:18 浏览次数: 1783 来源:大魏分享


一、剪枝与蒸馏的意义

从头开始预训练(pretrain)SLMs并不总是可行的,因为在数据收集、预训练(pretraining)管道等方面存在重大挑战。一种流行的替代方法是从较大的LLMs开始,并将其蒸馏(distill)为较小的模型。剪枝(pruning)和蒸馏(distillation)是该领域最受欢迎的两种技术。最近,NVIDIA发布了两个基于Llama 3.1–450B蒸馏(distilled)版本的模型,分别是Minitron-8B和Minitron-4B。

Minitron专注于通过剪枝(pruning)和蒸馏(distillation)来减少AI模型的大小,使其在不牺牲太多准确性的情况下更加高效。

  • 剪枝(pruning)通过切割层(深度剪枝,depth pruning)或移除神经元、注意力头(attention heads)或嵌入通道(embedding channels)来减少模型的大小。为了恢复一些丢失的准确性,剪枝(pruning)后通常需要重新训练。

  • 蒸馏(distillation)是一种相关技术,其中较小的模型(称为学生模型,student model)从较大、复杂的模型(称为教师模型,teacher model)中学习。目标是创建一个更紧凑的模型,保留较大模型的大部分预测能力,同时更快且对资源要求更少。


二、蒸馏方法:SDG微调经典知识蒸馏
Minitron确定了两种关键的蒸馏(distillation)风格。一种方法是SDG微调(fine-tuning),其中较小的预训练(pretrained)学生模型(student model)使用由较大教师模型(teacher model)生成的数据进行细化。在这种方法中,学生模型(student model)模仿教师模型(teacher model)预测的最终标记,如一些流行的教程和AI平台所示。

另一种方法,经典知识蒸馏(classical knowledge distillation),更为复杂。学生模型(student model)不仅关注预测的标记,还尝试复制教师模型(teacher model)的各种内部状态。这种技术在训练过程中提供了更详细的反馈,从而提高了准确性。然而,实施这种方法需要在训练框架中提供特定支持,因为它涉及处理来自教师内部状态的大量数据。

这两种方法并不互斥,可以相辅相成。Minitron主要强调经典知识蒸馏(classical knowledge distillation)方法。

三、剪枝(pruning)和蒸馏(distillation)工作流程
为了创建更高效的模型,Minitron将剪枝(pruning)与经典知识蒸馏(classical knowledge distillation)相结合。从较大的模型(如15B参数模型)开始,Minitron评估不同组件(层、神经元等)的重要性,然后将模型缩小到较小的尺寸,如8B模型。较小的模型经过轻量级的重新训练过程,从原始较大模型中学习。这个过程可以重复进行,以进一步减少模型大小,最终生成更小的版本,如4B模型。

剪枝(pruning)和蒸馏(distillation)过程是迭代的,每个较小的模型作为下一轮压缩和重新训练的基础。

参考上图,步骤解释如下:

这张图片展示了如何通过剪枝(pruning)和蒸馏(distillation)技术将一个大型语言模型(LLM)逐步缩小为更小、更高效的模型。让我们一步一步地解释这个过程:

1. 训练好的大型语言模型(Trained LLM)

首先,我们有一个已经训练好的大型语言模型(LLM),它包含了很多层(Layer),每一层都有不同的组件,比如嵌入(Embedding)、变压器块(Transformer Block)、注意力机制(Attention)、层归一化(Layer Norm)和多层感知机(MLP)。

2. 估计重要性(Estimate Importance)

接下来,我们需要评估每个组件的重要性。我们通过前向传播(forward propagation)来记录每个组件的活动情况,然后计算它们对模型整体性能的贡献。这一步帮助我们确定哪些组件是最重要的,哪些可以被移除。

3. 排序(Rank)

根据重要性评估的结果,我们对组件进行排序。重要性高的组件排在前面,重要性低的排在后面。这一步是为了在剪枝时有一个明确的优先级。

4. 剪枝(Trim)

在这一步,我们根据排序结果移除那些重要性较低的组件。比如,我们可能会移除一些嵌入(Embedding)、注意力头(Attention Heads)和嵌入通道(Embedding Channels)。这样,模型的大小就会减小,但仍然保留了大部分重要的功能。

5. 蒸馏(Distillation)

剪枝后,我们对模型进行蒸馏。蒸馏的过程是让一个较小的学生模型(Student Model)从较大的教师模型(Teacher Model)中学习。通过这种方式,学生模型可以保留教师模型的大部分预测能力,但体积更小,运行更快。

具体例子:从15B到8B,再到4B

图片的下半部分展示了具体的例子:

  • 从15B到8B:我们从一个15B参数的模型开始,经过上述步骤,剪枝和蒸馏后得到一个8B参数的模型。

  • 从8B到4B:然后,我们再从8B参数的模型开始,重复同样的步骤,最终得到一个4B参数的模型。

    通过这种方法,我们可以逐步将一个非常大的模型缩小为更小、更高效的版本,同时尽量保持其性能。

四、剪枝(pruning)影响
有效剪枝(pruning)模型需要了解其哪些部分是必不可少的。Minitron使用基于激活数据的方法来估计各种组件(层、神经元、注意力头(attention heads)和嵌入通道(embedding channels))的重要性,使用一个小数据集。这个方法只需要前向传播,比依赖反向传播和梯度计算的技术更简单且更具成本效益。

虽然可以在模型的不同部分之间交替进行剪枝(pruning)和重要性估计,但Minitron发现大多数情况下单轮重要性估计就足够了。


Minitron将这些技术应用于Llama 3.1模型家族,其中包括从405B到8B参数的模型。具体来说,他们专注于将8B模型蒸馏(distill)为更高效的4B版本。

微调教师模型
在剪枝(pruning)之前,Minitron微调了8B模型,以应对原始训练集数据分布的变化。如果没有这一步,教师模型(teacher model)可能无法在蒸馏(distillation)过程中为学生模型(student model)提供最佳指导。

深度剪枝(depth pruning)
为了将8B模型减少到4B,Minitron剪除了16层,通过逐一移除并跟踪性能影响来评估其重要性。他们发现模型开头和结尾的层对保持准确性最为关键。基于此分析,Minitron为最终的4B模型移除了特定的一组层。

宽度剪枝(width pruning)
除了深度剪枝(depth pruning),Minitron还沿宽度维度进行剪枝(pruning),目标是注意力头(attention heads)、嵌入通道(embedding channels)和隐藏层。剪枝(pruning)后,重新训练有助于恢复初始剪枝步骤中丢失的一些性能。有趣的是,尽管宽度剪枝(width pruning)最初导致的损失高于深度剪枝(depth pruning),但重新训练使模型随着时间的推移更有效地恢复。
五、使用经典知识蒸馏(classical knowledge distillation)进行重新训练
剪枝(pruning)后,Minitron使用经典知识蒸馏(classical knowledge distillation)重新训练较小的模型。这涉及通过在模型的各个阶段(包括嵌入输出、logits和变压器架构中的特定损失)最小化损失来教导剪枝后的模型。学生模型(student model)通过比较不同层的输出从未剪枝的教师模型(teacher model)中学习。

这张图展示了教师模型(Teacher Model)和学生模型(Student Model)在知识蒸馏(knowledge distillation)过程中的对比和学习关系。通过这个过程,学生模型从教师模型中学习,以便在保持性能的同时变得更小、更高效。让我们一步一步地解释这张图:

教师模型(Teacher Model)

  • 嵌入(Embeddings):输入数据首先通过嵌入层,将原始数据转换为适合模型处理的向量表示。

  • 多头注意力(Multi-Head Attention):嵌入向量经过多个注意力头,捕捉输入数据中的不同特征。

  • 前馈神经网络(Feed Forward, MLP):注意力机制后的输出经过前馈神经网络进行进一步处理。

  • 语言模型头(LM Head):最终的输出经过语言模型头,生成预测结果(logits)。

学生模型(Student Model)

学生模型的结构与教师模型类似,但通常规模较小:

  • 嵌入(Embeddings):输入数据同样通过嵌入层。

  • 多头注意力(Multi-Head Attention):嵌入向量经过多个注意力头。

  • 前馈神经网络(Feed Forward, MLP):注意力机制后的输出经过前馈神经网络。

  • 语言模型头(LM Head):最终的输出经过语言模型头,生成预测结果(logits)。

损失计算(Loss Calculation)

在知识蒸馏过程中,学生模型通过多种损失函数从教师模型中学习:

  1. 嵌入输出损失(Embedding Output Loss):学生模型的嵌入输出与教师模型的嵌入输出进行比较,计算损失。

  2. MLP输入损失(MLP Input Loss):学生模型的前馈神经网络输入与教师模型的前馈神经网络输入进行比较,计算损失。

  3. 编码器块输出损失(ENC Block Output Loss):学生模型的编码器块输出与教师模型的编码器块输出进行比较,计算损失。

  4. 语言模型头损失(LM Head Loss):学生模型的语言模型头输出与教师模型的语言模型头输出进行比较,计算损失。

  5. logits损失(Logit Loss):学生模型的最终预测结果(logits)与教师模型的最终预测结果进行比较,计算损失。

    通过这些损失函数,学生模型可以在多个层次上模仿教师模型,从而在保持性能的同时变得更小、更高效。


通过广泛的实验,Minitron总结了压缩语言模型的几个最佳实践:
· 模型大小:从训练最大的模型开始,然后逐步剪枝(pruning)和蒸馏(distillation)以创建较小的版本。
· 剪枝策略:对于高达15B参数的模型,重点关注宽度剪枝(width pruning)而非深度剪枝(depth pruning)。单次重要性估计通常足够。
· 重新训练:使用蒸馏损失(distillation loss)而非传统训练进行重新训练。当显著剪枝层时,结合logits、中间状态和嵌入的损失。对于较小的深度减少,仅使用logit蒸馏(logit-only distillation)。


五、单次重要性估计(single-shot importance estimation)

它是一种在剪枝(pruning)过程中使用的方法,用来确定模型中哪些部分可以被移除,而不会对模型的性能造成太大影响。这个方法的特点是只需要进行一次评估,而不是多次反复评估。

具体步骤如下:

  1. 前向传播(Forward Propagation):我们用一个小数据集让模型运行一次,记录下模型中每个部分(比如层、神经元、注意力头(attention heads)、嵌入通道(embedding channels)等)的活动情况。

  2. 重要性评估(Importance Estimation):根据记录下来的活动情况,计算每个部分的重要性分数。这些分数告诉我们每个部分对模型整体性能的贡献有多大。

  3. 剪枝(Pruning):根据重要性分数,移除那些分数较低、对模型性能贡献较小的部分。

    相比需要多次评估和调整的方法,单次重要性估计只需要进行一次全面的评估,因此更省时间和计算资源。虽然这种方法可能不如多次评估的方法那么精确,但在大多数情况下,它已经足够有效,能够显著减少模型的大小,同时保持较高的性能。

六、结果

NVIDIA在多个基准上评估了Minitron模型,结果与基线模型的性能相匹配。

Minitron模型在多个基准测试中的表现显示出其在保持较高精度的同时显著减少了模型的大小。具体来说,Minitron 4B在宽度剪枝(Width-pruned)和深度剪枝(Depth-pruned)版本中,尽管参数量减少,但在大多数基准测试中的精度仍然接近或优于Llama-3.1 8B。例如,在winogrande和MMLU测试中,Minitron 4B的宽度剪枝版本分别达到了0.7403和0.5860的准确性,接近Llama-3.1 8B的0.7727和0.6528。此外,Llama-3.1-Minitron 4B(宽度剪枝)在hellaswag和arc_challenge测试中表现尤为突出,分别达到了0.7606和0.5555的标准化准确性,甚至超过了Llama-3.1 8B。这表明,通过剪枝和蒸馏技术,Minitron模型在显著减少参数量的同时,仍能保持较高的性能。

参考:https://pub.towardsai.net/how-nvidia-pruned-and-distilled-llama-3-1-to-create-minitron-4b-and-8b-6646d42c92c6



53AI,企业落地应用大模型首选服务商

产品:大模型应用平台+智能体定制开发+落地咨询服务

承诺:先做场景POC验证,看到效果再签署服务协议。零风险落地应用大模型,已交付160+中大型企业

联系我们

售前咨询
186 6662 7370
预约演示
185 8882 0121

微信扫码

与创始人交个朋友

回到顶部

 
扫码咨询