AI知识库

53AI知识库

学习大模型的前沿技术与行业应用场景


聊聊Llama-Factory微调之loss计算与上下文长度
发布日期:2024-07-31 07:19:54 浏览次数: 2132 来源:阿郎小哥的随笔驿站


背景

github issue:关于多轮对话的loss计算[1]

该issue是之前我提给Llama-Factory的,主要是想了解下该框架微调时的Loss计算逻辑,其实就是mask的排列。一般使用Llama-Factory直接微调就完了,不需要也不会在意其内在的逻辑;但我是因为使用了Llama-Factory微调训练后,效果很差,才去了解其逻辑。我个人觉得,除却微调数据集的格式与质量外,还有两个需要关注的因素:上下文长度与Loss计算。

上下文长度

模型的输入支持的序列长度是做微调时需要了解并注意的,很多人在准备数据集后,往往会忽略数据集的大小与模型上下文长度的限制,因此导致微调训练效果不理想。

以最近开源的GLM-4-9B-Chat-1M为例,该模型支持1M的上下文输入,在目前来说算是最长的上下文序列了。

单卡下的相关测评:

GLM-4-9B-Chat-1M

精度显存占用PrefillingDecode SpeedRemarks
BF1675 GB98.4s2.3 tokens/s输入长度为 200000

虽说模型是支持1M上下文的输入,但是从机器配置与模型性能角度来考虑,建议是200K的上下文最好;需要超过200K的上下文,则需要考虑多卡部署推理,如基于vLLM等框架。

多卡部署的情况,以基于vLLM框架为例,关键的参数如下:

from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

# GLM-4-9B-Chat-1M
# max_model_len, tp_size = 1048576, 4
# 如果遇见 OOM 现象,建议减少max_model_len,或者增加tp_size
max_model_len, tp_size = 1310721
model_name = "THUDM/glm-4-9b-chat"
prompt = [{"role""user""content""你好"}]

即在4个并行的多卡下推理,支持的最长上下文是1M。

对于max_model_len参数的单位,即最长上下文参数的单位,一般是B。1 M = 1024 KB = 1024  * 1024 B;所以这里的参数值是 1048576。

更详细的内容,可参考GLM4官方文档:readme[2]basic_demo[3]

Loss计算

Llama-Factory微调框架的loss计算代码路径:code[4]

微调数据集展开后的格式为对话对,即 Q1 +A1 + Q2 +A2 + .... Q表示用户的输入内容;A表示AI的回复响应。

模型会将文本id化,即编码输入与输出。

在Loss计算的代码中,labels列表构建方式如下:

labels += [IGNORE_INDEX] * len(source_ids) + target_ids + [tokenizer.eos_token_id]

这里的labels列表对于每一轮对话(即每对Q和A),都会添加对应的target_ids(即A的编码)和eos标记。因此,对于每一轮对话的A部分,都会计算损失。

具体来说,代码的执行逻辑如下:

  1. 对于每个对话对(Q和A):
  • source_ids 对应 Q 的编码。
  • target_ids 对应 A 的编码。
  • labels 列表中添加 [IGNORE_INDEX] * len(source_ids),表示在 Q 部分的损失被忽略。
  • labels 列表中添加 target_ids 和 tokenizer.eos_token_id,表示在 A 部分计算损失。
  • 拼接后的示例:
    • Q1 + A1 + Q2 + A2
    • IGNORE_INDEX * len(Q1) + A1 + IGNORE_INDEX * len(Q2) + A2

    这意味着每一轮对话的 A 部分(即 A1 和 A2)都会计算损失,而不仅仅是最后一轮对话的 A 部分。

    其结构图可理解为:


53AI,企业落地应用大模型首选服务商

产品:大模型应用平台+智能体定制开发+落地咨询服务

承诺:先做场景POC验证,看到效果再签署服务协议。零风险落地应用大模型,已交付160+中大型企业

联系我们

售前咨询
186 6662 7370
预约演示
185 8882 0121

微信扫码

与创始人交个朋友

回到顶部

 
扫码咨询