大模型高效训练一体框架 LLaMA Factory
发布日期:2025-01-06 18:12:25
浏览次数: 1551
来源:DataFunSummit
导读 本文将分享如何基于 LLaMA Factory 实现大模型高效训练。
1. 大模型低资源训练技术总览
2. LLaMA Factory 整体介绍
3. LLaMA Factory 快速入门
4. LLaMA Factory 性能调优
5. LLaMA Factory 应用案例
6. Q&A
分享嘉宾|郑耀威 北京航空航天大学 研究员
编辑整理|王红雨
内容校对|李瑶
出品社区|DataFun
大模型低资源训练技术总览
在资源有限的情况下,实现高效训练大语言模型变得极具挑战。
以 LLaMA-3 8B 为例,其采用纯解码器结构,包含 32 层 Decoder 单元,每层由 Attention 和 MLP 两个核心组件构成,本质上是深度神经网络,隐藏层大小分别为 4096 维和 14336 维。因为维度较大,在前向传播和反向计算梯度时就需要较多的显存。GQA 查询注意力头数和键值注意力头数分别为 32 和 8。词表大小提升到了 128000,词表增大意味着可以编码更多 token,但随之也会增大显存的使用。以上 6 个数字就定义出了模型的显存结构。
另外,在训练时还会有另外两个数字,一个是批处理大小,假设值为 4,也就是每次喂给模型 4 个样本,并假设这些样本的序列长度为 8192。
通过以上数字就可以估算出模型训练时的显存占用。首先,可以计算出参数数量为 80 亿个。除了参数占用显存,还要存一些中间计算结果,也会占用显存,这部分又分为中间激活值、注意力分数和模型输出三类。如上图所示,这四类数字加起来已超过千亿,甚至可达到万亿级别,就需要上百 GB 的显存。
将以上数字转变为字节数,以 bfloat16 的精度存储模型权重,这样模型权重需占用显存约为 16GB。除此之外,模型参数、梯度和优化器状态一共会占用约 128GB 的显存,而激活值则需要 764GB。总共加起来需要近 1000GB 的显存,供 8B 模型训练。这一数量是不可接受的,因此需要一些优化策略来降低显存占用。
- 分块:将显存占用分摊到多个 GPU 上,可以使用 ZeRO3 或 FSDP 策略,或采用张量并行的方式。
- 卸载:放到 CPU 上存储,用的时候再拿到 GPU,这样可以将 GPU 显存占用降到很小,但可能会影响训练速度。
除了参数优化,还有梯度和优化器的优化。首先也可以做量化,使用 8-bit 优化器,可以将显存占用降低到原来的四分之一左右。最流行的方式是采用 LoRA 方法,即降低梯度和优化器的显存,采用矩阵分解的方式,使用低秩矩阵来模拟高秩矩阵,从而将显存占用降至原来的百分之一甚至更小,优化效果显著。类似的还有 Facebook 提出的 GaLore 方法。除此之外,还可以采用 Sampling 的方式,每次只优化一部分参数,主要算法包括 LOMO、Spectrum 等,其中 Badam 效果也非常好。梯度和优化器也可以分块,分摊至多张显卡,可以利用 ZeRO2 或 FSDP 策略,也可以使用张量并行。
最后再来讲一下激活值的优化方式。首先,分块计算,利用 Flash Attention 组件,可以将注意力分数优化至接近于 0 的大小。还可以利用 Fused CE 来对概率分布进行优化。重计算的另一个算法是 Checkpointing,其核心思路是时间换空间。最后一个方法是将激活值卸载到 CPU 上,可以进一步将显存占用从 bsh(l+1)降低至 2bsh。
假设模型参数采用量化策略,梯度/优化器采用 LoRA 策略,激活值采用上述所有方法,优化后总共仅用不到 10GB 的显存就可以完成 LLaMA-3 8B 模型的训练。可见上述方法的优化效果非常显著,然而要实现这些优化也并非易事,涉及复杂的编码和操作。因此我们提出了 LLaMA-Factory 框架。
LLaMA-Factory 整体介绍
LLaMA-Factory 集成了众多优化方法,使用时不再需要繁琐的代码编写,仅需在网页上操作即可实现算法调用。
框架结合了前面提到的 Flash Attention、GPTQ 作为核心算子,还集成了 LoRA 及其各种变体。基于这些算法,提供了大模型预训练、SFT、RLHF、DPO、SimPO 等优化策略,可以运行于 NVIDIA、Ascend、AMD 以及 Mac 等硬件之上。核心产品名为 LLaMA Board,通过可视化界面可实现零代码的模型微调。
LLaMA
Board 提供了四种语言,支持 300 多种模型,包括单模态的 LLaMA、Qwen,以及多模态的 LLaVA、Qwen-VL 等。平台集训练、评估于一体,提供了训练、评估、对话以及导出四种模式。用户可以根据需要选择数据集,并可视化地配置参数。可一键启动或暂停训练,启动后可实时观察训练损失曲线和训练进度。LLaMA Board 是完全开源的,可以从 GitHub 上下载,也可以看到源代码以便于进行定制开发。
在今年一年的时间中,LLaMA 不断升级,一共发布了 5 个大版本,加入了更多模型,适配了更多设备,并不断提升训练速度。
LLaMA-Factory 获得了社区非常的积极反馈,在 GitHub 已收获逾 35000 颗星标,成为微调框架中热门项目。并且被广泛应用于英伟达、亚马逊 AWS、腾讯云、阿里云等知名产品中,在业界受到了高度认可。已汇聚超 100 名贡献者,共同推进框架完善与功能丰富。
LLaMA-Factory 快速入门
大家通常会选择从指令模型开始,进行指令微调(SFT)。指令微调是初学者容易掌握的训练方法,适合那些刚开始接触大模型训练的新手。可以采用开源数据集,如 OpenHermes、AlignAnything、InfinityInstruct,训练模型,使其能够遵从人类指令,进行多轮对话,或执行复杂任务,拓展原有模型的功能边界。
对于高阶用户,也可以不使用 Web UI,而是通过命令行来进行操作。
开始训练之前需要关注数据构建,包括样本的收集、合成和增广。可以利用大模型去合成数据,比如将原来的非多轮对话数据改写为多轮对话。
根据任务难度选取模型,难度较小,比如只是做客服,就可以用 2B、7B 的模型,如果任务很难,就要选择较大的模型。另外,算法的选择要参考硬件资源情况,比如在硬件有限的情况下可以选择 LoRA 的方法,而硬件充足时则可以选择全量微调。
评估测试时,可以利用 LLaMA-Factory 自带的评估样例,也可以进行人工评测,来验证模型效果。
LLaMA-Factory 性能调优
在 LLaMA-Factory 的最新版本中,引入了高性能算子,来实现低显存上更高速度的训练。使用--enable_liger_kernel 以及--use_unsloth_gc 命令,就可以调用 Liger SwiGLU 和 Liger RMS Norm 算子。
从上图中可以看到,使用上述算子后,可以将样本最大长度从 4k 扩展到 32k,LoRA 微调样本甚至可以扩展到 64k。使用一张 40GB 的显卡就可以训练一个 8B 64k 的模型。
同时,还可以提升硬件利用率,从平均的 33% 升至 60%,节省了一半的训练时间,优化明显。
LLaMA-Factory 应用案例
最后通过实际案例来说明如何使用 LLaMA-Factory 框架。
例如,实现一个 AI 导游。首先数据构造阶段,从互联网中抓取了一些图文对,并通过模版构造出一些基础问答对。进一步需要做数据增强,基于最初构造出的几十条数据进行扩充,以便更好地训练模型。我们调用大模型合成风格化数据,让大模型扮演不同的角色,重写数据,这样得到了几百条数据。
我们利用这几百条数据去做 Qwen2-VL 2B 的全参微调。上图中可以看到,微调的损失函数是平缓下降的,可以拟合到一个非常低的水平。
上图中可以看到,原始模型给出的是一个错误答案,而微调后模型可以做出正确回答,并给出相关介绍。关于幻觉问题,建议在数据构建时混合一些通用样本数据,有助于提升模型准确率。
Q&A
A1:在 0.8 版本中已加入多卡训练功能,支持 GPU、NPU 及 AMD 设备,可在界面选择多卡训练。特别针对华为 920B3 完美适配,轻松安装无需额外配置即可使用。推荐使用提供的 Docker 镜像建立计算环境。
A2:Agent 微调数据应考虑多轮对话的不同角色,如 User Assistant、Function Call、Observation 等,可参照 Readme 文档中指定的格式构建。
Q3:是否使用 Deep Speed 进行底层训练?
A3:当前主要采用 Deep
Speed 与 FSDP 方式,以确保更广泛的模型兼容性。模型并行通过 Deep Speed ZeRO3 实现多显卡分配,已完成对 Deep Speed ZeRO3 的良好适配。
A4:Web UI 方式支持多卡训练,但对于多节点训练更推荐使用命令行,更为稳定,但 Web UI 环境下配置好环境变量也可实现。
A5:完全可行,可选用较智能的 GPT 模型合成样本;微调时可以选择开源小型模型,降低成本的同时保持良好效果。
A6:JSON 格式规范化数据结构,尽管需额外步骤,但确保数据一致性与模型训练的标准化。
产品:大模型应用平台+智能体定制开发+落地咨询服务
承诺:先做场景POC验证,看到效果再签署服务协议。零风险落地应用大模型,已交付160+中大型企业