微信扫码
与创始人交个朋友
我要投稿
XTuner[1] 是一个高效、灵活、全能的轻量化大模型微调工具库。与LLaMA-Factory类似,不过从官方的文档来看,在长序列训练、token生成速度等方面要比LLaMA-Factory更强。
从数据集来看,LLaMA-Factory支持多种格式的数据集,更通用泛化;而XTuner只支持类似ShareGPT格式的数据集。
从模型支持来看,LLaMA-Factory支持的模型种类也要比XTuner更多;但XTuner多模态模型(LLaVA-Internlm2-7B[2] / 20B[3]、LLaVA-v1.5[4])的支持要比LLaMA-Factory。
不过我个人比较关注的是多轮对话训练时的loss计算。从文档[5]来看,XTuner更清晰,而且是我想要的效果;而对于LLaMA-Factory,其放出来的只是数据集格式文档,而且之前我提过issue[6]想要了解其Loss计算逻辑,回复也是不清晰,相比较来看,LLaMA-Factory的loss计算没那么透明,只能啃源码。
另一方面就是多轮对话所对应的长序列训练性能。随着 Gemini 1M context length 和 Sora 出世,如何训练超长上下文的大模型引起了大家广泛关注。同时在大多数的场景下,多轮对话一般也就是一个conversations包含几轮对话;但在我实际遇到的情况中,一个conversations下有几百个对话,即长对话,这种场景还是比较多的。当然也有解决方案,就是比较麻烦,需要做拆分;在基座模型支持长上下文的情况下,如果微调框架能支持长序列训练,且性能不错,是很好的选择;XTuner在这方面要比LLaMA-Factory更好。见:支持序列并行训练策略以实现语言模型超长上下文训练![文档[7]] [速度基准[8]]。
在心理医疗方面,多轮对话产生的长序列更长久,尤其是心理咨询重的医患问询链内容。
大语言模型 Supervised Finetune(SFT)旨在通过有监督的微调来提高预训练模型在特定任务上的性能。为支持尽可能多的下游任务,XTuner 支持了增量预训练、单轮对话、多轮对话三种数据集格式。
在指令微调阶段,我们的目标是训练语言模型根据人类指令给出回答。 因此,一般只有回答部分(Output)的 loss 会用于梯度回传,而指令部分(System、Input)部分的 loss 则不会用于权重更新。 基于此,我们在对数据集进行预处理的时候引入了 "system"、"input" 和 "output" 三个字段,"system"、"input" 字段用于保存不需要计算 loss 的文本,例如系统或用户指令,而 "output" 字段则用于保存需要计算 loss 的文本,例如输入指令对应的 GroundTruth 回答。
为了统一增量预训练、单轮对话和多轮对话三种数据集格式,我们将数据集格式设置为以下形式:
[{
"conversation":[
{
"system": "xxx",
"input": "xxx",
"output": "xxx"
}
]
},
{
"conversation":[
{
"system": "xxx",
"input": "xxx",
"output": "xxx"
},
{
"input": "xxx",
"output": "xxx"
}
]
}]
在训练过程中,我们会将一条数据中的多组 "system"、"input" 和 "output" 进行拼接,之后输入模型,并行计算每个位置的 loss ,但只有 "output" 部分对应的 loss 参与梯度回传,如下图所示。
其中
由于增量预训练旨在帮助模型学习针对特定下游任务的语言知识和表达能力,因此数据集的全部内容对应的 loss 都应该用于梯度回传。因此,数据集的 "system"、"input" 为空,而 "output" 为一整条语料数据。增量预训练任务对应的数据集格式如下所示:
[{
"conversation":[
{
"system": "",
"input": "",
"output": "I am an artificial intelligence (AI) assistant named Puyu. I was created by the Shanghai AI Laboratory and my purpose is to assist users with various tasks through natural language processing technology."
}
]
},
{
"conversation":[
{
"system": "",
"input": "",
"output": "I am an artificial intelligence programmed to assist with various types of tasks, including answering questions, providing information, and performing automated processes."
}
]
}]
单轮对话数据集往往由一条指令(或问题)及其对应 GroundTruth 回答组成。由于只有回答部分需要对 loss 进行回传,因此数据集的 "system"、"input" 字段为输入指令,"output" 字段为对应回答。单轮对话数据集格式如下所示:
[{
"conversation":[
{
"system": "You are an AI asssistant."
"input": "Give three tips for staying healthy.",
"output": "1.Eat a balanced diet. 2. Exercise regularly. 3. Get enough sleep."
}
]
},
{
"conversation":[
{
"system": "You are an AI asssistant."
"input": "How to study English?",
"output": "1. Set clear goals. 2. Create a study plan. 3. Build vocabulary. 4. Practice speaking."
}
]
}]
多轮对话数据集往往由多轮指令(或问题)+ 对应 GroundTruth 回答组成。假设我们现在有一条多轮对话数据,内容如下。
为方便介绍,对于第 n 轮对话,我们将 User 和 Assistant 对应的输出设为 UserN 和 AssistantN。
System: You are an AI asssistant.
User1:Hello?
Assistant1:Hello! How can I help you?
User2:What's the date today?
Assistant2:Today is Monday, August 14, 2023.
User3:Thank you!
Assistant3:You are welcome.
如何使用上述这条多轮对话数据训练大模型?目前有以下两个主流方法。
System、User1、Assistant1、User2、Assistant2、User3的文本都视为模型的输入部分,将 Assistant3 的文本视为模型的预测部分,只有 Assistant3 部分的 loss 参与权重更新。
这种方法的弊端在于没有充分利用多轮对话的训练数据,因为 Assistant1 和 Assistant2 的内容没有参与模型训练,导致训练数据利用率较低。
将一条多轮对话数据,拆分成多条数据。例如将以上示例拆分成如下三条数据。
相比于方法1,方法2可以充分利用每一轮对话的数据,但需要将一条包含 n 轮对话的数据拆分为 n 条数据,训练效率降低 1/n。
XTuner 训练多轮对话模型时,采取了一种更加充分高效的方法,如下图所示。
我们将多轮对话进行拼接,之后输入模型,并行计算每个位置的 loss,而只有 Output 部分的 loss 参与回传。因此 XTuner 中多轮对话数据集格式如下所示:
[{
"conversation":[
{
"system": "You are an AI asssistant."
"input": "Hello?",
"output": "Hello! How can I help you?"
},
{
"input": "What's the date today?",
"output": "Today is Monday, August 14, 2023."
},
{
"input": "Thank you!",
"output": "You are welcome."
}
]
},
{
"conversation":[
{
"system": "You are an AI asssistant."
"input": "Hello?",
"output": "Hello! How can I help you?"
},
{
"input": "How's the weather today in Rosso?",
"output": "The weather in Rosso on Wednesday, August 16th, is going to be cloudy for most of the day, together with moderate rain around noon."
},
{
"input": "Thank you!",
"output": "You are welcome."
}
]
}]
数据集中的 "conversation" 键对应的值是一个列表,用于保存每一轮对话的指令和实际回答(GroundTruth)。为了保持格式统一,增量预训练数据集和单轮对话数据集中的 "conversation" 键也对应一个列表,只不过该列表的长度为 1。而在多轮对话数据集中,"conversation" 列表的长度为 n,以容纳 n 轮的对话内容。
官方效果图如下:
随着生成性 AI 的不断发展,长序列训练正在变得非常重要。具有长上下文能力的大模型开始逐渐取代 RAG 成为信息检索的重要解决方案。代码库理解和例如 Sora 这种视频生成任务都需要在空间和时间层面对长上下文进行推理,而Gemini已经支持1M上下文的输入。
相比较8K、32K上下文的模型输入支持,长序列输入在以后的模型中一定是发展的方向。
XTuner 中的序列并行设计思路参考了 DeepSpeed 的工作 DeepSpeed Ulysses[9],并加以优化,以达到直接基于 transformers 算法库或 Huggingface Hub 上的开源模型训练 1M 以上超长序列的目标。Ulysses结构设计如下:
对于XTuner更细节的长序列支持的实现,可以查看原文。见文档[10]。
模型 | 序列并行支持情况 |
---|---|
baichuan 1/2 | ❌ |
chatglm 2/3 | ❌ |
deepseek | ✅ |
gemma | ❌ |
internlm 2 | ✅ |
llama 2 | ✅ |
mistral | ❌ |
qwen 1/1.5 | ❌ |
starcoder | ❌ |
yi | ✅ |
zephyr | ✅ |
因此对于需要用到长序列训练的场景,要考虑基座模型的选择。XTuner单独有个parallel[11]包封装了长序列并行训练的逻辑。而且提供了API给其它repo调用,抽离了整个长序列并行训练的逻辑。
后续,我个人也会基于XTuner来做微调训练,尤其是验证长序列的训练结果。总的来说,XTuner的文档与技术,都是很令人满意的。
XTuner: https://github.com/InternLM/xtuner/blob/docs/README_zh-CN.md
[2]LLaVA-Internlm2-7B: https://huggingface.co/xtuner/llava-internlm2-7b
[3]20B: https://huggingface.co/xtuner/llava-internlm2-20b
[4]LLaVA-v1.5: https://github.com/haotian-liu/LLaVA
[5]文档: https://github.com/InternLM/xtuner/blob/main/docs/zh_cn/user_guides/dataset_format.md
[6]issue: https://github.com/hiyouga/LLaMA-Factory/issues/3729
[7][文档: https://github.com/InternLM/xtuner/blob/docs/docs/zh_cn/acceleration/train_extreme_long_sequence.rst
[8][速度基准: https://github.com/InternLM/xtuner/blob/docs/docs/zh_cn/acceleration/benchmark.rst
[9]DeepSpeed Ulysses: https://arxiv.org/abs/2309.14509
[10]文档: https://github.com/InternLM/xtuner/blob/docs/docs/zh_cn/acceleration/train_extreme_long_sequence.rst
[11]parallel: https://github.com/InternLM/xtuner/blob/docs/xtuner/parallel/sequence/attention.py
53AI,企业落地大模型首选服务商
产品:场景落地咨询+大模型应用平台+行业解决方案
承诺:免费场景POC验证,效果验证后签署服务协议。零风险落地应用大模型,已交付160+中大型企业
2024-08-13
2024-03-30
2024-05-28
2024-05-10
2024-04-26
2024-04-12
2024-04-25
2024-07-25
2024-05-06
2024-07-18
2025-01-22
2025-01-22
2025-01-22
2025-01-22
2025-01-21
2025-01-21
2025-01-20
2025-01-18