AI知识库

53AI知识库

学习大模型的前沿技术与行业应用场景


小数据,大突破!揭秘仅0.3B个token如何让8B模型逼近GPT-4,长文本开源新纪元
发布日期:2024-10-25 16:16:48 浏览次数: 1599 来源:深度学习自然语言处理



当前,越来越多的研究指出,长文本模型(Long-context Model, LCM)在输出时可能会遇到多种问题,其中最为突出的是幻觉(Hallucination)和指令不遵循(Instruction Unfollowing)现象。以下面的示例来说明:

首先,模型会接收到一段详细的背景信息,例如关于美剧《老友记》(Friends)的某个情节的描述。然后,可能会有用户提出这样的问题:“在《老友记》中,Rachel和Monica的职业分别是什么?”根据背景信息,正确的回答应该是Rachel是一名服务员(waitress),而Monica是一名厨师(chef)。然而,长文本模型有时会出现幻觉现象,错误地告诉你Rachel是一名护士(nurse),或者完全忽视问题,转而讲述一段毫不相关的内容,比如“Jackey喜欢游泳”(当然,这种情况出现的概率相对较低)。总的来说,幻觉现象是最常见的问题。


针对长文本模型所表现出的这些问题,可以借鉴先前的研究,将它们统称为长文本领域的不对齐/偏差现象(Misalignment Phenomenon)。近期,众多研究工作揭示了长文本领域中存在的Misalignment现象(LongHalQA[1], L-CiteEval[2])。


长上下文中的对齐与非对齐(Alignment & Misalignment)

在深入探讨长文本模型(Long-context Model, LCM)的非对齐或偏差现象之前,有必要先梳理一下目前的研究进展。

长文本模型为什么会出现偏差(Misalignment)?

一个较为底层且广泛接受的解释是,长文本模型中包含了检索头(Retrieval Head[3]),这些检索头专门负责根据输入的问题定位长上下文中的关键片段,然后模型可以根据这些关键片段进行回复。(本人也通过这篇论文受益匪浅)。长文本模型出现偏差无非只有两种原因:1)检索头没有定位到关键信息;2)检索头定位到关键信息,但是模型没法处理这些关键信息有效输出

如何做长文本对齐?

很简单,训练是最好的方法。然而,遗憾的是,之前大多数长文本模型的研究都集中在如何更好地扩展模型的上下文窗口(Long Context Scaling),即让模型能够“看到”更多的信息。相比之下,关于如何在已有长上下文窗口的模型上获得更好的效果(Long Context Alignment),相关的研究工作却相对较少。这方面的研究主要集中在制作更高质量的“长文本指令数据”,例如Long Alpaca,或者通过LongLoRA、PoSE等高效微调方法来提升长文本模型的效果。

但问题在于,这些方法主要还是集中在监督式微调(Supervised Fine-Tuning, SFT)阶段。众所周知,针对非对齐问题,SFT可能只能做到“哪里不足补哪里”,其收益相对有限。最终,可能还是需要借助强化学习(Reinforcement Learning)这一大杀器来有效地对齐模型,这一点在下面的实践中也会进行详细阐述。


一种基于强化学习的长上下文对齐策略:LOGO

最近,我们在长文本处理领域尝试了一种新的强化学习对齐策略,名为LOGO(Long cOntext aliGnment via efficient preference Optimization)。

Arxiv & Project 传送门(Code还在整理,近期在开发和贡献OpenRLHF[4]库,应该大概率集成或者继承OpenRLHF库之后再Release,月底之前会release出一版本code):

论文:https://arxiv.org/abs/2410.18533
项目:https://github.com/ZetangForward/LCM_Stack

基于偏好对齐的长文本对齐效果展示

在介绍LOGO方法之前,我们先展示一个preview:该策略仅需在0.3B个token上进行训练,就能够显著提升Llama-3-8B-Instruct模型在长文本实际任务中的表现(这里我们选取的是LongBench数据集),甚至逼近GPT-4这种强闭源模型。

通过图2,我们还能发现一个问题:图(b)中模型的Retrieval Score都大差不差,说明这些模型都能很好的定位长上下文中的关键信息,但是输出结果(Recall Score)确大相径庭。这也和上面的猜想联系起来了:长文本模型可以定位关键信息,但是没法很好的处理结果,换句话说,不知道什么样的结果是对的,什么样的结果是错的。

如何在长上下文场景下做基于强化学习的对齐?

在长上下文场景下,采用强化学习进行训练(题主这里采用的是偏好对齐,DPO[5])需要考虑两个问题:

  1. 训练的可行性?长文本模型通常需要较大的显存和计算资源。传统的直接偏好优化(Direct Preference Optimization, DPO)需要在GPU中同时存储参考模型(Reference Model)和策略模型(Policy Model)。对于参数量庞大且词汇表庞大的模型(训过的应该知道,词表越大,越会有一个GPU Comsumption Peak),这可能导致显存消耗达到峰值,使得训练变得困难。需要开发或优化训练库以适应长上下文模型的训练需求,这可能涉及大工程量(文末会安利一个训练框架~)。
  2. 数据哪里来?在长文本领域,通常缺乏专门标注了对齐/非对齐的数据对。如果需要自己构造这样的数据,长文本模型生成结果的评测可能难以依赖于人工或强模型(如GPT-4),因为这既耗费资金,又存在标注难度。巧妇难为无米之炊,缺少数据自然也就很难在长文本上做偏好对齐了。

LOGO分别从建模目标(Training Objective),长偏好数据构建(Long-Dependency Preference Data Construction) 以及 高效训练(Efficient Training)三个方面介绍了长文本场景下的强化学习训练方法(下面的内容和原文的讲述思路不太一样,但是内容是一样的)。

1. Training Objective

在长上下文场景下,基于强化学习的对齐策略面临着显存和训练数据的双重挑战。近期很多工作表明,直接使用DPO会出现严重的Reward Hacking和生成能力退化的现象,为了解决这个问题,他们在DPO公式的基础上加了很多约束项,这样不仅让DPO变得更加难以训练和优化,在长文本场景下,也加剧了训练部署的困难。我们主要参考了一个Reference-Free且对生成模型友好的训练目标函数:SimPO[6]。SimPO的目标函数可以写成:

其中 是策略模型(待优化的模型), (奖励差异的缩放因子)和 (目标奖励边际)是用来区分偏好和非偏好响应的超参数。

考虑到偏好数据构建的困难,LOGO采取了一种折中策略,即不特定标注错误类别,而是将所有错误都视为负样本。这样,通过扩大负样本空间,一方面避免了标注的困难,另一方面可能使模型更好地学习。因此,LOGO的目标函数可以写为:

这里 是一个偏好数据中的负样本数目,即一条训练数据中包含一条正样本和个负样本。最后,考虑到Reward Hacking和生成能力退化的问题,我们加了一个SFT的正则项(通过 去控制相对强度)优化目标函数:

2. Long-Dependency Preference Data Construction

有了上述的建模目标,构建对应的偏好数据则方便和简单许多,因为我们压根不需要考虑特定的错误类型(题外话:让模型生成错误答案还是非常简单的)。具体可以参考下面这个示意图:

简单来说分为三个步骤:

  • 关键片段打分:对长上下文进行片段(Chunk)的划分,然后通过实体抽取模型(NER)抽取出每个片段中和问题相关的实体,如果实体越多,说明这个片段对最终答案的贡献越大
  • 引导模型生成偏好数据:通过组合不同分数的片段作为上下文(context),引导模型生成偏好数据。例如,可以利用贡献大的片段引导模型生成正确的结果,然后利用一部分贡献大+一部分贡献小的片段引导模型生成“可能出现幻觉的结果”,用完全没有贡献的片段引导模型生成不遵循指令的结果。这样做有两个优点:1)每个片段的长度都不长,实际切分下来每个片段只有1024个tokens,组合8个片段也在Llama3-8B 8K的上下文范围内,便于构造数据;2) 片段长度不长,降低了对最后结果判断正确与否的难度,可以利用GPT4等高级模型进行评测。
  • 位置编码合成:这是一个在长文本领域(可能比较冷门)但是很实用的Trick(苏神、Weifuru、Fuyao等大佬都有相关的工作,可以去搜一搜)。简单来说就是通过调整位置编码来影响Transformer-based模型对序列长度的感知。
  • 正常来说,一段文本都是从0开始标位置,一直标到N-1 : (N是序列长度)。
  • 如果现在标的是 ,通过跳过一些位置编号,模型会认为序列长度是(远远大于 ),从而实现对“无限长”数据的构建。
  • 虽然方便构建,但是会出现中间缺失的片段信息。所以,原文中采用了混合位置编码合成的方式,在长度扩充的同时,尽可能保证每个Batch中所有的位置全部覆盖,细节可以参考原文的Appendix D。

3. Efficient Training

最后就是如何部署高效的训练了。一方面,通过上述的Reference-Free 训练目标函数和位置编码合成的策略可以大大减少所需的GPU显存大小,LOGO采用了LoRA(Low-Rank Adaptation)训练方法,这是一种参数高效的微调技术,它通过在预训练模型的权重上添加低秩矩阵来实现微调,而不是直接修改原始权重。在叠满这些buff之后,所有的训练仅仅在一台A800的机器上完成,且在16小时的训练时间之内,就可以让Llama-3-8B-Instruct在长文本任务上的性能得到提升和改善。

最终的训练结果 & 一些消融实验带来的 Insight

真实任务上的效果 这里我们采用了LongBench

合成任务上的效果,测试LOGO的长文本窗口扩充的能力

短文本任务上的效果,这里我们挑选了MMLU、TruthfulQA、ARC三个常用的测试集,分别覆盖了多种语言任务、事实核查和多项选择问答等不同的挑战。我们发现,如果全量训练模型(下面中间图红色的箭头),长文本模型可能会在长文本任务上取得不错的性能,但这种训练方式往往会导致模型在短文本任务上的能力下降。这种现象可以被看作是“对齐税”。LOGO通过使用LoRA训练方法,有效地避免了全量训练可能带来的问题,所以所以原始的能力非但没有丢失,还小涨了一点。

(部分)分析实验:不同长文本对齐算法之间的比较。我们比较了两种标准的SFT策略:一种是在整个序列上施加交叉熵(CE)损失,另一种是仅在最终预测上施加CE损失。实验结果显示,这两种方法都会导致模型性能达到一个瓶颈。一方面,这可能是由于训练数据的质量所限;另一方面,CE损失仅推动模型输出接近真实标签,但并未教会模型如何避免错误答案。反观LOGO,可以当模型获得一个持续增长的效果,同时,我们还通过检索头计算了Retrieval Score,越大的Retrieval Score表明模型可以定位到更多的重要信息。这表明,LOGO甚至还激活了检索头的能力。

结语

随着人工智能技术的不断进步,长文本模型在理解和生成自然语言方面展现出了巨大的潜力。然而,这些模型在处理长文本时常常会遇到幻觉和指令不遵循等问题,这些问题限制了它们在实际应用中的有效性。

在探索长文本对齐的征途上,位置编码的激活(窗口扩充技术)已经被很好的解决。同时,我们已见证令人振奋的进展,如o1-preview模型的问世,它们在长推理链和真实使用场景中展现出了卓越的性能。然而,这一领域的研究与应用仍然面临着诸多挑战和未知,随着模型推理能力越发被重视,未来的研究可能会集中在如何进一步提升这些模型的推理能力和效率,以及如何将它们更好地集成到实际应用中等。

最后,我依然保持着这个观点: 长文本模型的研究应该从“更长上下文”的竞赛中跳出来,转而探索如何使模型在给定的上下文中做出更精准、更符合人类思维方式的回应。这可能涉及到模型的理解和推理能力的进一步提升,以及如何让模型的决策过程更加透明和可解释。



53AI,企业落地应用大模型首选服务商

产品:大模型应用平台+智能体定制开发+落地咨询服务

承诺:先做场景POC验证,看到效果再签署服务协议。零风险落地应用大模型,已交付160+中大型企业

联系我们

售前咨询
186 6662 7370
预约演示
185 8882 0121

微信扫码

与创始人交个朋友

回到顶部

 
扫码咨询