AI知识库

53AI知识库

学习大模型的前沿技术与行业应用场景


大模型的训练过程总结
发布日期:2024-03-23 21:22:21 浏览次数: 3414 来源:AI少数派报告


上图来自 Andrej Karpathy,深度学习的大拿,目前在Tesla。这张图的信息量相当大,通过该图能让我们对大模型的整个训练过程有一个总体的了解。

从该图可知大模型训练主要有4步:

  1. Pretraining — 预训练阶段
  2. Supervised Finetuning(SFT) — 监督微调,也叫指令微调阶段
  3. Reward Modeling — 奖励模型训练阶段
  4. Reinforcement Learning(RL)— 增强学习微调阶段

下面分别就这四个阶段进行说明。

预训练阶段

这个阶段的产出物是基础模型(base model),基础模型通常不会被直接使用,因为它只能完成续写,无法完成特定的任务。比如你问它中国的首都是?它可能会输出一系列的选项,如:a) 上海,b)北京,c)巴黎 ……,因为训练语料可能就包括了这样的选择题。

这个阶段是最消耗算力的阶段, 基本上99%的算力用在这个阶段,主要因为这个阶段训练的数据量巨大,最近发布的大模型训练数据基本都达到了2T~3T的token。这些语料形式多样,包括网络内容、论文、代码,再加上多语种,目前国内开源大模型的训练语料一般是中英双语。根据MPT公开的资料,训练一个MPT-7B Base基础模型动用了440张A100共训练了9.5天,对应的成本达到20w刀。

为了让大模型能具备特定的能力,如对话,就必须对大模型进行微调,那么就进入到下一个阶段:监督微调或叫指令微调。

指令微调阶段

这个阶段的难点不再是对算力的高要求,转而对微调所需的语料质量有非常高的要求,对语料的总体要求是少而精,如上图所示问答对(prompt和答案)的数量一般在10K~100K,这些语料通常是人工编写的,也有利用chat-gpt这种超牛大模型输出问答语料,目前网上也能找到比较丰富的开源的指令微调语料。以MPT-7B-Chat模型为例,该chat模型是基于MPT-7B Base基础模型微调出来,微调的语料包括:ShareGPT-Vicuna, HC3, Alpaca, Helpful and Harmless, 及Evol-Instruct这些开源语料。对于一个chat模型来说,它应该能回答用户的提问或根据用户的指令进行输出,所以语料中包括了各种问答对、指令集等内容,例如Alpaca的语料形式如下:

{
    "instruction""Create a classification task by clustering the given list of items.",
    "input""Apples, oranges, bananas, strawberries, pineapples",
    "output""Class 1: Apples, Oranges\nClass 2: Bananas, Strawberries\nClass 3: Pineapples",
    "text""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nCreate a classification task by clustering the given list of items.\n\n### Input:\nApples, oranges, bananas, strawberries, pineapples\n\n### Response:\nClass 1: Apples, Oranges\nClass 2: Bananas, Strawberries\nClass 3: Pineapples",
}

“text”部分是一个prompt模板,将instruction, input及output进行格式化。该prompt作为微调语料,微调后的模型就能接受这样的prompt格式,完成用户的指令,比如将以下prompt喂给微调后的chat模型:

Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
sort the the given numbers

### Input:
20, 19, 33, 1

### Response:

模型会输出:1, 19, 20, 33。

这个阶段完成后其实已经能获得一个可以上线的大模型了(SFT模型)!对于那些需要私有化部署并只要求大模型完成特定任务的场景,指令微调出来的大模型已经完全满足要求。比如nl2sql的场景。

后面两个阶段的主要目的是让大模型有更好的表现,并且与人类喜好、意识形态进行对齐(有用、诚实和无害)。目前比较主流的技术就是RLHF(reinforcement learning from human feedback)— 也就是后面两个阶段的工作。一般企业内部私有化部署的大模型很少会去做RLHF。首先,RLHF很难,需要人工对模型的输出进行打标,成本太高。其次,没有这个必要,对于那些nl2sql,增强分析,RAG等场景基本不涉及“政治正确”这个维度。当然,对于通用大模型或各种客服类的模型与人类价值观对齐还是必须的。

奖励模型训练阶段

这个阶段会输出一个奖励模型(RM模型),该模型不直接用于业务场景,RM模型的作用是对SFT的输出进行打分—根据人类的习惯进行排序。这么做有什么意义呢?举例来说,你可以要求大模型生成一份营销文案,但这个文案是否满足你的要求,写得是否足够好是一件很难量化的事情。然而,如果拿三份文案来让你挑选,那么你通过横向比较,很容易就能选出最好的那篇。

如何训练出RM模型?这个阶段也涉及到大量的人工作业,主要是对SFT模型的输出进行排序。具体来说就是用同一个prompt从SFT模型获取多个输出,然后人工对这些输出进行排序。从上图可知,这个工作量相当大,需要准备100K~1M条待比较数据并且对数据质量的要求很高,要知道这些记录都需要人工进行比较排序,每组比较往往需要不同人来打标进行交叉比对。

有了这些排序后的数据就可以训练出一个RM模型,RM模型接受SFT模型的输出并给出一个奖励分值,SFT模型的输出越符合人类的价值观则RM模型输出的奖励分值越高,反之则越低。举个例子,向SFT模型输入prompt:“狗是人类的?”,SFT模型给出两个输出:

  1. 人类的朋友
  2. 毛茸茸的猛犸象

将prompt和以上两个输出喂给RM模型,RM会给第一个输出高分,给第二个输出很低的分数。

增强学习微调阶段

有了RM模型就能通过增强学习(RL)来微调模型,以实现大模型同人类偏好对齐。这个阶段依旧需要不少人工作业。如上图所示,需要编写10K~100K条prompts并且对质量的要求很高。RL微调阶段的主要目标是获取最高的奖励分值。整个微调过程大致如下:

  1. 将准备好的prompt分别喂给SFT模型和RL微调中的模型,并分别获取输出y1和y2。注意,这里还是会用到SFT模型,具体原因会在下面给出
  2. 将y2输入至RM模型获取奖励分值r,这个分值越高代表模型对齐效果越好
  3. 将y1与y2计算散度k,例如KL divergence,这个值的作用类似regularization,y1与y2的散度越高,说明RL微调模型与SFT模型这两者表现的偏离越大,也表示RL微调模型为了获取更高的奖励分数出现了过拟合
  4. 以(r-λk)作为奖励函数,通过更新RL微调模型的参数最大化奖励分值。

总结

以上4个阶段构成了完整的GPT模型训练的pipeline,从中可以看出训练大模型是一个非常艰巨的任务,例如对庞大算力资源的要求、对高质量语料数据的要求。另外,在训练大模型的时候一般需要基于一些优化框架,如DeepSpeed,这些工程化方面的任务也有不少坑。因此,对于一般的企业而言通常不建议自己训练基础大模型,如果必须进行私有化部署,可以根据实际情况选择一款开源大模型,如有必要可基于开源大模型进行微调,通过这个途径可以用比较少的投入来高效地落地大模型应用。




53AI,企业落地应用大模型首选服务商

产品:大模型应用平台+智能体定制开发+落地咨询服务

承诺:先做场景POC验证,看到效果再签署服务协议。零风险落地应用大模型,已交付160+中大型企业

联系我们

售前咨询
186 6662 7370
预约演示
185 8882 0121

微信扫码

与创始人交个朋友

回到顶部

 
扫码咨询