AI知识库

53AI知识库

学习大模型的前沿技术与行业应用场景


聊聊GLM-4-9B开源模型的微调loss计算
发布日期:2024-07-14 06:02:35 浏览次数: 1902 来源:阿郎小哥的随笔驿站


概述

Github官方地址:GLM-4[1]

网上已经有很多关于微调的文章,介绍各种方式下的使用,这里不会赘述。我个人比较关心的是微调时的loss计算逻辑,这点在很多的文章都不会有相关的描述,因为大多数人都是关心如何使用之类的应用层,而不是其具体的底层逻辑,当然咱也说不清太底层的计算。

微调

微调格式:

[
  {
    "messages": [
      {
        "role""system",
        "content""<system prompt text>",
        "tools": [
          {
            "name""<tool name>",
            "args": {
              "<arg name>""<arg value>"
            }
          }
        ]
      },
      {
        "role""user",
        "content""<user prompt text>"
      },
      {
        "role""assistant",
        "content""<assistant response text>"
      },
      {
        "role""user",
        "content""<user prompt text>"
      },
      {
        "role""assistant",
        "content""<assistant response text>"
      },
      {
        "role""observation",
        "content""<observation prompt text>"
      },
      {
        "role""assistant",
        "content""<assistant response observation>"
      },
      {
        "role""user",
        "content""<user prompt text>"
      },
      {
        "role""assistant",
        "content""<assistant response text>"
      }
    ]
  }
]

微调源码地址:finetune.py[2]

Loss计算代码:

def process_batch(
        batch: Mapping[str, Sequence],
        tokenizer: PreTrainedTokenizer,
        max_input_length: int,
        max_output_length: int,
)
 -> dict[str, list]:

    batched_conv = batch['messages']
    batched_input_ids = []
    batched_labels = []
    # batched_conv 是一个数组
    # conv 是数组内的单个 message
    for conv in batched_conv:
        input_ids = [151331151333]
        loss_masks = [FalseFalse]
        # conv 是数组内的单个 message
        # message 是 单个role json对象
        for message in conv:
            message = process_message(message)
            # 设置 mask 掩码,只有system,user,observation不参与mask计算,其余的角色参与计算
            loss_mask_val = False if message['role'in ('system''user''observation'else True
            # 获取 input 文本的数字表示(ids)
            new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[0][2:]
            # 计算整句的 mask
            new_loss_masks = [loss_mask_val] * len(new_input_ids)
            # 拼接message中的每段json
            input_ids += new_input_ids
            # 拼接message中每段json对应的mask
            loss_masks += new_loss_masks
        # 追加结尾的 token id
        input_ids.append(tokenizer.eos_token_id)
        loss_masks = [False, *loss_masks]
        labels = []
        for input_id, mask in zip(input_ids, loss_masks):
            if mask:
                # 添加到label,计算loss
                labels.append(input_id)
            else:
                # -100 不处理,即ignore_index
                labels.append(-100)
        max_length = max_input_length + max_output_length + 1
        # 截断
        batched_input_ids.append(input_ids[:max_length])
        batched_labels.append(labels[:max_length])
    return {'input_ids': batched_input_ids, 'labels': batched_labels}

注释在代码中已经写明。process_batch方法用于将输入转换为ids,并计算mask(用于Loss计算)。而该方法的调用是在数据集的遍历处理中,即如下所示:

tokenizer, model = load_tokenizer_and_model(model_dir, peft_config=ft_config.peft_config)
data_manager = DataManager(data_dir, ft_config.data_config)
# 数据集拆分遍历
train_dataset = data_manager.get_dataset(
    Split.TRAIN,
    functools.partial(
        process_batch,
        tokenizer=tokenizer,
        max_input_length=ft_config.max_input_length,
        max_output_length=ft_config.max_output_length,
    ),
    batched=True,
)
print('train_dataset:', train_dataset)

Loss计算如下图所示:

总结

相比较于之前的ChatGLM版本,GLM4开源版本的多轮对话loss计算更恰当且效率也会更高;在其它的开源模型/微调框架中早已支持该种loss计算,如InternLM、XTuner、Firefly等。对于loss格式的类别,可参考XTuner的官方文档说明:dataset_format.md[3]

Reference
[1]

GLM-4: https://github.com/THUDM/GLM-4

[2]

finetune.py: https://github.com/THUDM/GLM-4/blob/main/finetune_demo/finetune.py

[3]

dataset_format.md: https://github.com/InternLM/xtuner/blob/main/docs/zh_cn/user_guides/dataset_format.md



53AI,企业落地应用大模型首选服务商

产品:大模型应用平台+智能体定制开发+落地咨询服务

承诺:先做场景POC验证,看到效果再签署服务协议。零风险落地应用大模型,已交付160+中大型企业

联系我们

售前咨询
186 6662 7370
预约演示
185 8882 0121

微信扫码

与创始人交个朋友

回到顶部

 
扫码咨询