AI知识库

53AI知识库

学习大模型的前沿技术与行业应用场景


深入解析LLM预训练与SFT对齐:Loss函数差异与代码解析
发布日期:2024-12-16 07:33:57 浏览次数: 1630 来源:电商小ai



LLM(Large Language Model,大型语言模型)在预训练和对齐阶段,虽然都使用loss函数来指导模型学习,但两者在loss的设计和目标上存在显著差异。

1. 预训练阶段:

  • 目标: 学习语言的通用表示,掌握语法、语义、知识等。
  • 数据: 海量、未标注的文本数据,例如书籍、网页、代码等。
  • Loss函数: 通常使用自监督学习方法,例如:
    • Masked Language Modeling (MLM): 掩盖句子中部分词语,让模型预测被掩盖的词语。
    • Causal Language Modeling (CLM): 根据前面的词语预测下一个词语。
  • Loss特点:
    • 关注模型对语言结构和知识的理解。
    • 数值较大,因为模型需要学习大量信息。
    • 随着训练的进行,loss逐渐下降,表示模型对语言的理解能力不断提升。
  • 预训练Loss代码:
    • transformers库中的源代码,包含在trainer中的compute_loss,会在预估的prediction_step和training_step函数中被调用,实现的源代码在LabelSmoother类中,具体实现如下:
 544 @dataclass
 545 class LabelSmoother:
 546     """
 547     Adds label-smoothing on a pre-computed output from a Transformers model.
 548
 549     Args:
 550         epsilon (`float`, *optional*, defaults to 0.1):
 551             The label smoothing factor.
 552         ignore_index (`int`, *optional*, defaults to -100):
 553             The index in the labels to ignore when computing the loss.
 554     "
""
 555
 556     epsilon: float = 0.1
 557     ignore_index: int = -100
 558
 559     def __call__(self, model_output, labels, shift_labels=False):
 560         logits = model_output["logits"if isinstance(model_output, dict) else model_output[0]
 561         if shift_labels:
 562             logits = logits[..., :-1, :].contiguous()
 563             labels = labels[..., 1:].contiguous()
 564
 565         log_probs = -nn.functional.log_softmax(logits, dim=-1)
 566         if labels.dim() == log_probs.dim() - 1:
 567             labels = labels.unsqueeze(-1)
 568
 569         padding_mask = labels.eq(self.ignore_index)
 570         # In case the ignore_index is -100, the gather will fail, so we replace labels by 0. The padding_mask
 571         # will ignore them in any case.
 572         labels = torch.clamp(labels, min=0)
 573         nll_loss = log_probs.gather(dim=-1, index=labels)
 574         # works for fp16 input tensor too, by internally upcasting it to fp32
 575         smoothed_loss = log_probs.sum(dim=-1, keepdim=True, dtype=torch.float32)
 576
 577         nll_loss.masked_fill_(padding_mask, 0.0)
 578         smoothed_loss.masked_fill_(padding_mask, 0.0)
 579
 580         # Take the mean over the label dimensions, then divide by the number of active elements (i.e. not-padded):
 581         num_active_elements = padding_mask.numel() - padding_mask.long().sum()
 582         nll_loss = nll_loss.sum() / num_active_elements
 583         smoothed_loss = smoothed_loss.sum() / (num_active_elements * log_probs.shape[-1])
 584         return (1 - self.epsilon) * nll_loss + self.epsilon * smoothed_loss
  • shift_labels:是否需要位移计算,logits = logits[..., :-1, :]预估值,从第0个到倒数第二个,labels = labels[..., 1:]为label,原始文本,从第1个到结束,label中的第0个为输入,预估结果从第一开始,计算loss也是。
  • log_probs:softmax函数的计算,转换为概率分布。
  • padding_mask:padding_mask = labels.eq(self.ignore_index),label中特殊标记token(pad_token_id)为padding,可以计算出padding_mask来,刨除padding_mask外的,参加loss计算。
  • nll_loss:核心计算nll_loss = log_probs.gather(dim=-1, index=labels),log_probs是一个形状为 (batch_size, sequence_length, vocab_size) 的张量,表示模型对每个词的预测概率的对数。
    • batch_size:批处理大小,即一次处理多少个样本。
    • sequence_length:序列长度,即句子中有多少个词。
    • vocab_size:词汇表大小,即模型认识多少个不同的词。
    • 例如,log_probs[0, 2, 500] 表示模型预测第一个样本中第三个词是词汇表中第500个词的概率的对数,从log_probs到nll_loss,主要做了平滑和去掉padding。
    • labels:形状为 (batch_size, sequence_length) 的张量,每个词的真实标签 (ground truth),例如,labels[0, 2] 表示第一个样本中第三个词的真实标签。
    • gather(dim=-1, index=labels):从log_probs获取labels对应位置的值;dim=-1:表示在最后一个维度(即 vocab_size 维度)上进行操作;index=labels:使用 labels 张量作为索引来获取log_probs 中的值。


2. 对齐阶段 (SFT: Supervised Fine-Tuning):

  • 目标: 将预训练模型的能力迁移到特定任务,例如对话生成、文本摘要、机器翻译、LLM落地到垂类业务场景等。
  • 数据: 针对特定任务的标注数据,例如对话记录、摘要文章、翻译文本、垂类业务数据等。
  • Loss函数: 通常使用监督学习方法,根据具体任务选择合适的loss函数,例如:
    • Cross-Entropy Loss: 用于分类任务,例如情感分析、意图识别等。
    • Mean Squared Error (MSE) Loss: 用于回归任务,例如文本评分、机器翻译质量评估等。
  • Loss特点:
    • 关注模型在特定任务上的表现。
    • 数值相对较小,因为模型只需要微调预训练的参数。
    • 随着训练的进行,loss逐渐下降,表示模型在特定任务上的表现不断提升。
  • 对齐sft Loss代码:
    • 对齐sft loss对不同的训练框架实现稍微有些区别,但本质都是一样的,都会先对prompt部分剔除或者mask掉,然后调用预训练transormers库的loss计算,以开源框架LLaMA-Factory中的sft进行解读。
    • LLaMA-Factory中的sft loss计算代码在train/sft/trainer.py中,具体实现在prediction_step函数中,详细如下:
 81     @override
 82     def prediction_step(
 83         self,
 84         model: "torch.nn.Module",
 85         inputs: Dict[str, Union["torch.Tensor", Any]],
 86         prediction_loss_only: bool,
 87         ignore_keys: Optional[List[str]] = None,
 88     ) -> Tuple[Optional[float], Optional["torch.Tensor"], Optional["torch.Tensor"]]:
 89         r"""
 90         Removes the prompt part in the generated tokens.
 91
 92         Subclass and override to inject custom behavior.
 93         "
""
 94         labels = inputs["labels"if "labels" in inputs else None
 95         if self.args.predict_with_generate:
 96             assert self.tokenizer.padding_side == "left""This method only accepts left-padded tensor."
 97             labels = labels.detach().clone() if labels is not None else None  # backup labels
 98             prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
 99             if prompt_len > label_len:
100                 inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"])
101             if label_len > prompt_len:  # truncate the labels instead of padding the inputs (llama2 fp16 compatibility)
102                 inputs["labels"] = inputs["labels"][:, :prompt_len]
103
104         loss, generated_tokens, _ = super().prediction_step(  # ignore the returned labels (may be truncated)
105             model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
106         )
107         if generated_tokens is not None and self.args.predict_with_generate:
108             generated_tokens[:, :prompt_len] = self.tokenizer.pad_token_id
109             generated_tokens = generated_tokens.contiguous()
110
111         return loss, generated_tokens, labels
112
113     def _pad_tensors_to_target_len(self, src_tensor: "torch.Tensor", tgt_tensor: "torch.Tensor") -> "torch.Tensor":
114         r"""
115         Pads the tensor to the same length as the target tensor.
116         "
""
117         assert self.tokenizer.pad_token_id is not None, "Pad token is required."
118         padded_tensor = self.tokenizer.pad_token_id * torch.ones_like(tgt_tensor)
119         padded_tensor[:, -src_tensor.shape[-1] :] = src_tensor  # adopt left-padding
120         return padded_tensor.contiguous()  # in contiguous memory
  • 部分参数含义:因为计算sft计算loss时,prompt部分不参与,需要从labels中刨除掉,使用pad_token_id特殊token掩盖掉prompt
    • padding_side:padding_side == "left" assert为左边padding,否则报错(需要可以自行修改,但需要把一些padding的逻辑都一起改了)。
    • if prompt_len > label_len则在prompt 张量中将prompt部分用pad_token_id特殊token mask掉,具体实现在_pad_tensors_to_target_len中,if label_len > prompt_len 同理。
    • mask掉prompt部分后,调用transformers库基类Trainer的prediction_step,即可计算出输出部分的loss(详细代码看预训练部分loss)。

总结:

特征预训练对齐 (SFT)
目标学习通用语言表示迁移到特定任务
数据海量未标注数据高质量标注数据
Loss函数自监督学习 (MLM, CLM)监督学习 (Cross-Entropy, MSE)
Loss特点数值较大,关注语言理解数值较小,关注任务表现

需要注意的是,以上只是一些常见的区别,实际情况可能更加复杂。例如,有些预训练任务也会使用少量标注数据,而有些对齐任务也会使用自监督学习方法。

总的来说,预训练和对齐阶段的loss函数设计都至关重要,它们共同决定了LLM最终的性能。



53AI,企业落地应用大模型首选服务商

产品:大模型应用平台+智能体定制开发+落地咨询服务

承诺:先做场景POC验证,看到效果再签署服务协议。零风险落地应用大模型,已交付160+中大型企业

联系我们

售前咨询
186 6662 7370
预约演示
185 8882 0121

微信扫码

与创始人交个朋友

回到顶部

 
扫码咨询