AI知识库

53AI知识库

学习大模型的前沿技术与行业应用场景


如何提升大模型的SFT效率与效果?
发布日期:2024-05-09 07:20:03 浏览次数: 2705 来源:NetRookie



 大模型的SFT技术应用的十分广泛,有非常多的公司从开始的prompt engineering开始接触大模型,然后随着开源代码框架的生态的井喷式发展,微调的技术成本已经非常低了。所以我们研究的事项得更进一步,如何提升大模型SFT得效率和效果?即怎么让效果更好,或者怎么在效果相同的情况下,需要得样本量或者人工参与量更少,这篇文章的目的,一是要review下主要的sft方法,二是在过程中,尝试思考出比较合理得方案,供读者进行工业落地的选型。


Takeaways:

  • 什么是微调:特定任务/特定领域/few shot/蒸馏/parameter efficient/动态微调

  • LLaMa/Yi/Baichuan2/Telechat等开源基础模型公开的sft方法

  • 基于SFT的偏好学习:DPO、Chain of Hindsight



01 



什么是微调



    微调的概念,严格定义为使大语言模型更加适应特定的任务,微调的策略主要包含:

  • task specific adaptation:调整通用模型适配特定任务,主要包含:数据收集,微调训练,评估与优化三阶段

  • domain specific finetuning:调整通用模型适配特定领域,主要包含:收集领域数据,预训练+微调训练,评估迭代

  • few shot learning:通过高质量的少样本进行上下文学习

  • 知识蒸馏:通过更大的teacher指导小模型进线迁移学习优化

  • 多任务学习

  • Parameter-Efficient微调:如lora,使用更少的参数变化带来特定场景任务的优化。

  • 动态微调:需要使用实时数据进行模型的优化,比如根据用户实时数据优化模型

02



开源大模型的SFT优化


重新温习开源鼻祖:llama2 and llama3 and  LIMA


LLaMa 2

  • "Quality is all You need:We stopped annotating SFT after collecting a total of 27,540 annotations"

  • System Message训练:将system message这种全局功能放在模型每一句话前面

  • Safety Context Distillation:在提示前加上安全提示来生成更安全的模型回复,例如:你是一个安全且负责任的助手;然后在没有提示的情况下对模型进行微调

  • 微调细节:学习率 2e-5、weight decay=0.1、batch size=64、seq length=4096、epoch=2


LLaMa 3

  • "combination of supervised fine-tuning (SFT), rejection sampling, proximal policy optimization (PPO), and direct policy optimization (DPO)"。使用经典的指令微调和偏好对齐的排列组合,说明这些方法的潜力在我们实践中还没被挖掘充分,方法已经够好去做项目的落地。

  • 数据质量对sft和偏好对齐方法都有非常巨大的影响;数据迭代很多轮,以保证质量

  • PPO和DPO对推理和代码任务有很大提升。发现:当模型很难搞定推理任务时,模型有时候能产生推理路径,也就是能产生正确大答案只是不知道怎么去选择,偏好学习能加强这种选择。


LIMA:

  • 超级对齐假设:"A model’s knowledge and capabilities are learnt almost entirely during pretraining, while alignment teaches it which subdistribution of formats should be used when interacting with users."

  • Data Diversity, Quality, and Quantity 中 Diversity和Quality最重要

  • 训练细节:1000条高质量样本,使用EOT作为多轮中每轮的结束符号,避免和预训练过的EOS Token有冲突。epochs=15;AdamW, beta1 = 0.9, beta2 = 0.95, and weight decay of 0.1,学习率从1e−5 线性衰减至1e−6,batch size=32, seq_length=2048.


后来者:主流开源大模型提到的SFT是怎么做的?baichuan mistral qwen deepseek yi 电信 以下只提炼出潜在的个人觉得对sft工作有帮助的点


baichuan2:

  • 做safety有个细节的工作,分成了7个大类,每个类别1万条。"To ensure comprehensive coverage within each category, We ask human annotators to generate 1,400 data samples. This was further expanded through self-instruction and cleaned by humans for fluency, resulting in 70,000 total samples with 10,000 per category"

  • 另外,baichuan2 reward也做了类似的事情,足以可见对场景应用层次的细化程度。 “We devised a three-tiered classification system for all prompts, consisting of 6 primary categories, 30 secondary categories, and over 200 tertiary categories.”


mistral 7b:

  • 使用开源的数据和方法微调,没有发现透露出更多的细节


电信大模型:

  • 人工标注:得到100,000 supervised fine-tuning samples

  • trick1:"noisy embeddings for enhanced model performance in scenarios NEFTUNE"


    • NEFTune(Noisy Embedding Instruction Finetuning):

    • 主要做法:给embedding添加噪声,这种做法在传统NLP里面也有过类似的研究,没想到大模型也同样适用的很不错。直接看代码比较好理解。


      from torch.nn import functional as F
      def NEFTune(model, noise_alpha=5) def noised_embed(orig_embed, noise_alpha): def new_func(x): # during training, we add noise to the embedding # during generation, we don't add noise to the embedding if model.training: embed_init = orig_embed(x) dims = torch.tensor(embed_init.size(1) * embed_init.size(2)) mag_norm = noise_alpha/torch.sqrt(dims) return embed_init + torch.zeros_like(embed_init).uniform_(-mag_norm, mag_norm) else: return orig_embed(x) return new_func ##### NOTE: this is for a LLaMA model ##### ##### For a different model, you need to change the attribute path to the embedding ##### model.base_model.model.model.embed_tokens.forward = noised_embed(model.base_model.model.model.embed_tokens, noise_alpha) return model
    • 主要结论:

      NEFTune Can Improve Chat Models:提升了conversational ability and answer quality; NEFTune在知识Capabilities上也基本没有损失,所以可以认为是一个很好的对话能力提升trick。

       

  • trick2:"multi-stage long-context training to expand TeleChat’s context window":  多阶段逐步提升上下文窗口, 预训练4k,继续sft到 3000步学习率3e-4到8k、1000步4e-5到16k,batchsize都是8M;采用了tensor并行和pipeline并行;再使用NTK-aware外推方法提升到 96k


qwen:说的不多,放一些训练参数吧


  • "AdamW optimizer, with the following hyperparameters: β1
    set to 0.9, β2 set to 0.95, and ϵ set to 10−8. The sequence length is limited to 2048, and the batch size is 128. The model undergoes a total of 4000 steps, with the learning rate gradually increased over the first 1430 steps, reaching a peak of 2 × 10−6."

  • "To prevent overfitting, weight decay is applied with a value of 0.1, dropout is set to 0.1, and gradient clipping is enforced with a limit of 1.0."


deepseek:


  • 数据:"helpful data contains 1.2 million instances, with a distribution of 31.2% for general language tasks, 46.6% for mathematical problems, and 22.2% for coding exercises. The safety data consists of 300K instances, covering various sensitive topics"

  • "We fine-tuned our 7B model with 4 epochs, but only 2 epochs for the 67B model, 1e-5 and 5e-6 for 7B and 67B "

  • 当训练集中数学数据集上升时出现重复问题:"the repetition ratio tends to rise as the quantity of math SFT data"。可能是数学数据集学习难度太高?

  • DPO, with a learning rate of 5e-6 and batch size of 512 =》 能提升 open-ended generation skill


Yi:


  • 重视质量,也强调了一句:"Quality is All You Need:Our finetuning dataset consists of less than 10K multi-turn instruction response dialog pairs, with each and every one of the entry constructed and polished over multiple iterations and from user feedback"

  • 提升分布选择:使用WizardLM中的Evol Instruct方法 =》 可以明显减少sft所需要的数据


  • CoT data formatting:采用“Step-Back” pattern ,这里不太清楚是选型的原因是什么,是为了RAG的应用?

  • 减少幻觉的方案:

    • "examine and ensure that the knowledge in the responses is not contained within the model。" 确保回复中的知识不包含在模型中,可能是希望sft和pretrain不产生知识冲突。

    • "eliminate responses that might lead to memorization"  减少会让模型记忆的回复。  

  • 减少重复的方案:"we rewrite the repetitive turns of the responses that usually exist but may be overlooked in the finetuning data。" 改写模型经常重复的轮次的回复

  • 多样性和数据配比怎么保证

    • "encompassing areas such as question answering, creative writing, dialogue, reasoning, mathematics, coding, safety, bilingual capabilities, and others.  " 尽可能增加各种任务类别

    • "InsTag:By designing a diversity-focused sampling algorithm, we carefully balanced the distribution of instructions across various tags 。"  涉及多样性采样方法,在不同tag中进行混合

    • "grid search to determine our data mixture:{1, 1/2, 1/4, 1/8, 1/16, 1/32, 1/64} proportions for each ability"  工作真细致!使用网格搜索调整混合比例

  • 参数:AdamW optimizer with β1 set to 0.9, β2 set to 0.999, and ϵ set to 10−8 ;seq_length=4096, batch size=64, NEFTune 45 for Yi-34B-Chat and 5 for Yi-6B-Chat


03



多任务学习SFT


     这部分提两篇工作,分别是多任务学习、多语言学习:T5的数据配比、 多语言推理场景的sft增强。

T5中多任务学习的采样

  • Examples-proportional mixing:根据每个任务数据集的大小按比例进行采样。实验证明不够稳定。高资源比较受益

  • Temperature-scaled mixing温度采样rm = min(em, K)/ 求和 min(en, K) 。每个任务的比例为rm的1/T, 其中K为认为设置的样本最大值。T=1等价于等比例采样,T越大越接近均匀采样。 实验证明这是个从大多数任务中获得合理性能的方法,其中 T = 2 在大多数情况下表现最佳。建议尝试

  • Equal mixing:均匀混合,相当于不管任务数据的多少,都采样相同的比例。实验证明效果最差

MathOctopus 跨语言数据增强对多任务的提升

《Breaking Language Barriers in Multilingual Mathematical Reasoning: Insights and Observations》这篇论文虽然讨论的是大模型多语言数学推理的问题,但实际上对于我们针对特定能力的优化也有一定的参考,但实际上可能是通过不同语言的语义对齐进行了数据增强。从而提升一定的下游相关任务能力。我们实验发现不同语言的数据进行混合,与单语言微调相比性能确实会得到提升

04



偏好学习SFT


    RLHF PPO之后影响较大的自然是DPO,当然还有很多进一步的优化如KTO、IPO等,但本质上算不上新范式或一个类别的优化,方法还是经典的好,看4月18号发出来的llama3,也是只用了DPO,这里先不多说KTO、IPO之类,有空再研究。

    DPO通过一个简单的分类损失直接优化策略,消除了在微调过程中从LLM采样或进行大量超参数调整的需要。DPO通过增加偏好响应相对于非偏好响应的相对对数概率来更新。DPO感觉确实有点supervised contrastive learning的意思。

损失函数:

训练的时候有两个模型,Ref模型和Policy模型。对比loss的含义为:左边为Good Response 右边为Bad Response

1、左边变大,右边变小 

2、左边变小,右边更小

3、左边变大,右边不是很大

所以训练过程中只要模型相对于reject来说更关注chosen即可,不用关注reward的绝对值。

构造数据:提供prompt, chosen, reject的格式class DPOLoss(nn.Module):    """    DPO Loss    """
    def __init__(self, beta: float, label_smoothing: float = 0.0) -> None: super().__init__() self.beta = beta self.label_smoothing = label_smoothing


def forward( self, policy_chosen_logps: torch.Tensor, policy_rejected_logps: torch.Tensor, reference_chosen_logps: torch.Tensor, reference_rejected_logps: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: pi_logratios = policy_chosen_logps - policy_rejected_logps ref_logratios = reference_chosen_logps - reference_rejected_logps logits = pi_logratios - ref_logratios
        losses = ( -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) - F.logsigmoid(-self.beta * logits) * self.label_smoothing )
loss = losses.mean() chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach() rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()
return loss, chosen_rewards, rejected_rewards
Reference: https://github.com/OpenLLMAI/OpenRLHF/blob/main/openrlhf/models/loss.py


CoH:Chain of Hindsight 事后诸葛亮

    "human preferences into rich and detailed feedback:Chain of Hindsight (CoH)",它通过将所有类型的反馈转换为句子序列,并利用这些序列对模型进行有监督微调,从而让模型能够理解并有效地利用这些反馈。

05



思考总结


    无论何时,数据是第一位,要考虑质量、多样性、数量的权衡,对于落地来说,按项目按业务分层级分类的策略是非常有必要的。对于推理能力提升,需要改造CoT数据格式提升推理能力,通过Evol Instruct产生更多丰富样本;针对幻觉问题和重复问题,SFT需要对回复内容做过滤;对于多任务来说,需要要进行细粒度的数据配比和跨语言混合尝试;对于对话能力,则使用偏好学习通过利用了更多的信息,有可能能提升生成内容的质量,提升泛化能力,更让用户所接受。从本文介绍的这些方法来看,武器库还是很多的,希望能给大家带来一点收获。

欢迎关注,私信或留言交流!

Reference:

  1. Llama2 https://arxiv.org/pdf/2307.09288.pdf  
  2. Llama3 :https://ai.meta.com/blog/meta-llama-3/  
  3. LIMA:less is more for AlignMent
    https://arxiv.org/pdf/2305.11206.pdf
  4. Chain of Hindsight aligns Language Models with Feedback
    https://arxiv.org/pdf/2302.02676.pdf 
  5. Breaking Language Barriers in Multilingual Mathematical Reasoning
    https://arxiv.org/abs/2310.20246 
  6. DPO https://arxiv.org/pdf/2305.18290.pdf
  7. https://github.com/huggingface/alignment-handbook 
  8. TELECHAT TECHNICAL REPORT sft部分
  9. qwen技术报告 sft部分
  10. mistral技术报告 sft部分
  11. baichuan2 sft
  12. deepseek sft
  13. yi sft
  14. 数据增强:
    https://arxiv.org/pdf/2403.02990.pdf
  15. 多样性
    https://ojs.aaai.org/index.php/AAAI/article/view/29955
  16. NEFTUNE https://arxiv.org/pdf/2310.05914.pdf 


53AI,企业落地应用大模型首选服务商

产品:大模型应用平台+智能体定制开发+落地咨询服务

承诺:先做场景POC验证,看到效果再签署服务协议。零风险落地应用大模型,已交付160+中大型企业

联系我们

售前咨询
186 6662 7370
预约演示
185 8882 0121

微信扫码

与创始人交个朋友

回到顶部

 
扫码咨询