微信扫码
与创始人交个朋友
我要投稿
本文作者:赵呈路
编辑整理:CastellanZhang
最近大模型的长上下文处理能力备受关注,组内同学也做了调研和实验,结合多种优化技术,将大模型的窗口长度从4K提升到60K。
在大模型实际使用过程中,因大模型存在幻觉问题,对于特定领域回答准确率并不高,在实际场景如法律咨询、电商售前售后咨询等方面使用受限,根据我们的评测,在这方面业内一流模型诸如GPT4.0,表现也并不尽如人意。为了增强大模型表现,业内比较常见的方法为检索增强生成(RAG)和微调(Finetune)。RAG顾名思义,通过外接知识库,增强大模型在专业领域知识,并且得益于注意力机制,RAG后模型能够复现知识库中内容,而后通过模型改写输出自然语言。但注意力机制算法训练成本高,正常来说随着长度扩展,显存消耗是平方级增长的,因而模型存在窗口限制,这个窗口也被称为大模型上下文长度,通常市面上开源模型窗口长度大多在8-32K左右,因此外挂知识库的容量有限。一种解决办法为:结合向量化召回,召回相关知识。因此模型表现上也类似于传统搜索推荐,召回集越多,模型上限也就越高,而这也恰恰意味着模型上下文越长,模型最终效果也会越好。而微调则是通过继续预训练(CPT)和有监督微调(SFT)来给模型注入知识,这个方面我们通过对训练数据进行知识增强提高样本的多样性,在专业领域知识的注入和提取上能做到92%+的准确率,远高于市面上开源模型。但因本文主要探讨如何增加模型上下文窗口,这里知识注入和提取方面就不再详述,感兴趣可以关注后续相关文章。
目前市面上大多数开源模型训练数据长度在4K以内,依靠插值外推等方式可以扩展到32K。考虑到大模型大多数能力来自预训练阶段,因此要想获得100K的上下文窗口,就需要大量长度在100K以上的训练数据,而attention显存占用是随着训练数据长度呈平方级增长,因此要想使模型具备100K以上的上下文能力,仅通过少量样本在现有硬件8 * 80G 显卡上无法实现。得益于【1】的工作,作者通过5B左右的训练数据,继续预训练模型,使得模型在英文领域具备128K左右的上下文的能力,这个训练代价远小于从头开始训练一个长上下文的模型。【2】【3】提到的flash-attention 以及 flash-attention 2 方法又使得模型在attention计算过程中,显存占用和训练时间进一步缩减,特别的 flash-attention 2 使得在一个8*80G的A100上,训练一个支持上下文长度在100K左右的模型变得可能。【4】提出的DeepSpeed-Ulysses又可以进一步降低模型显存,而【5】提出的想法理论上可以使模型具备无限长的上下文能力。文章【6】提及制约大模型(如llama、llama2)长文本能力的原因,主要在于位置编码ROPE会随着距离进行衰减,导致模型在上下文长度增加时,位置编码区分度小,因此一些能保证位置编码区分能力的方法显得特别重要,【6】提出通过减小旋转角度来提高位置编码的表示能力,【7】提出插值方式来提高模型外推能力。得益于上述工作的开源,使得我们在一台 8 * 80G A100的机器上训练一个100K长文本模型变得可能。
大模型长文本能力来自于预训练,那么预训练阶段不同文本长度的loss情况直接决定模型对长文本的支持能力,因此需要测算模型在不同文本长度下loss收敛的情况,以及如果单纯改变训练文本长度,大约需要多少step模型可以收敛至稳态。
参考文章【6】给出的loss公式
设置不同的 ,来绘制不同上下文长度下的验证集loss,从这里可以外推出,如果想要128k的上下文长度,7B模型loss应该在1.4以下。
第二个就是判断改变不同训练文本的长度后,能否在有限的step内把loss降到收敛值附近。
总体可以看出如果改变训练文本长度,大约需要1000个step模型就能达到收敛的稳态。如果按照4M token计算,模型大约需要4B训练数据就能达到想要的loss。
为了让模型支持100K左右的上下文能力,并且不损伤模型原本的通用能力,我们按照原始模型的数据配比,分别从新闻、问答、图书、wiki、code等领域收集约2B左右的数据,其中长文本的占比大约为70%,数据配比上82%来自网页抓取,4.5%来自code、4.5%来自书籍、4.5%来自Wiki、2%来自问答、2.5%来自法律文书和政府以及企业年报。
模型方面选择llama 2 7B进行实验,参考文章【6】【7】的方法对位置编码进行改进。训练参数上,因本身数据源为中文,和llama使用的数据的数据分布上存在一定差异,因此不同于之前预训练模型constant学习率,我们采用30步warm-up让llama模型达到预定的 ,而后lr_scheduler_type 采用 cosine。在batch_size上,采用与原始llama相同的4M token数量进行学习。
通过大海捞针实验:
从大海捞针实验可以看出,模型能扩展到60K左右的长下文长度,但过长后模型会存在问题,可能和模型没有收敛以及编码衰减相关,后续需要进一步改进这两方面。
我们知道GPU擅长进行矩阵乘法计算,GPU上存在一块内存空间名字为SRAM,类似于缓存能够快速读写,但空间有限,HBM空间(也就是显存)较大但读写较慢,从下图可以看出SRAM读写速度是普通显存HBM的10倍以上。
而GPU的计算带宽是远高于显存的读写带宽的,例如,A100-40GB SXM的计算带宽为312 TFLOPS,显存带宽为1555GB/s,因此定义算术强度为
,如果算术强度高于201,则计算受限,如果低于则显存受限,我们考虑标准的attention其算术强度为
其中N为序列长度,d为embedding维度,考虑到N远大于d,因此最终结果约等于 ,而一般经过multi-head之后的embedding维度大约在64左右,因此整个计算是 内存受限。因此如果能够把数据放置在SARM上并减少对于显存的访问次数,不仅能够降低显存占用还可以加快attention的计算。但缓存的大小有限,我们假设一块SARM能够存储1000个数据,向量Q、K维度均为100*10。而2*100*10的数据点是无法存放在SRAM上,因此必然需要从显存中读写数据。但是如果能够对矩阵进行分块,把 100 * 10的矩阵拆分成2个 50 * 10 的小矩阵,这样就能直接从SRAM读取。标准的attention计算方式如下:
从标准的attention计算方式上可以看出,假设单个样本在经过multi-head切分后embedding长度为d,序列长度为N。对上述序列做attention,如果直接把中间矩阵存放到显存上需要 空间,而且从显存读写次数共计8次。而分块后能明显减低显存占用,假设分成两块,即使不改进attention算法,显存占用将降低到1/4。理论很美好,但这里存在一个问题,分块计算attention是否与整体计算attention是等价。attention步骤,如:内积、mask、dropout等分块与不分块类似,而一旦涉及softmax,直觉上可能会出现问题,分块会导致序列长度变短,因为分块必然是从序列长度这个维度进行切分,因此需要找到合适方法,让分块softmax等价整体计算softmax。
传统softmax容易出现极化导致梯度消失和数值不稳定的现象,对于float32和bfloat16来说, 当 时,就会变成inf,发生数据上溢的问题。为了避免发生数值溢出的问题,保证数值稳定性,目前主流的softmax版本为 safe softmax版本:
这里的N维向量 为 的第k行和 相乘的结果。
在传统的safe-softmax的基础上,大约需要3次循环能完成online-softmax的计算,整个循环如下:
可以看出相比正常的attention方式,safe-attention 复杂度主要来自 计算,因为在计算attention过程中只需要最终结果 ,因此可以把 计算改写成下面形式:
因此safe-attention的计算方式可以变成:
到这里似乎无法优化,因为 我们需要的变量,而这个值计算非常依赖N,但是我们注意到最终attention的计算公式为:
我们最终需要也是 ,即各项的加权平均,而非各个分项的,因此我们把 改成迭代形式:
因此attention计算方式从原本的3次循环变成单次循环:
初始化:,, 为d维零向量。最后得到的 即为我们所要的结果。
有了这个单次循环的函数 ,我们就可以将 , 切分,分块计算。比如切分为两块,,,这里的 都为 的矩阵。
为了简化起见,我们还是只关注 的第k行,设为行向量 ,则
都是N/2维度的向量,我们可以调用一次函数 ,得到 ,然后再调用一次函数 ,只不过这次调用的初始化是利用上次的结果 ,可以验证 得到的 和直接调用 的结果 是相等的。
这就证明了分块计算attention与整体计算attention是等价的。flash-attention是一种无损的计算方式,而且整个算法就可以把空间复杂度控制在 上,极大降低了显存的消耗。 将这个分块的思想继续推广,就得到完整版的flash-attention算法,整个前向过程如下:可以看到整体计算上,为了减少计算次数,提高效率,伪代码中尽可能利用已经计算好的数据。
反向传播时,标准的attention模型因为本身存储了中间矩阵S和P,因此反向传播时正常计算即可。而flash-attention因为降低显存的目的,并没有保存中间结果,而是通过Recompute的方式,整体反向传播如下:
到这里我们要分析下flash-attention计算量和显存占用,写回显存的占用量大约是,计算量上与标准的attention相同,不过加上重计算flash-attention的计算量反而增加了,但得益于显存访问的减少,其计算效率得到较大提升,更为重要的把显存占用从 降低到这一点对长上下文窗口的模型非常重要。
flash-attention对于旋转位置编码是能够支持的,因为旋转位置编码是先作用到Q、K上,相当于在flash-attention的外部,而ALiBi是直接作用到attention-score上面的,需要自己写flash-attention代码后才能实现。
相较于flash-attention,V2版本的flash-attention主要在下面三个方面对其进行改进:
首先来回顾下flash-attention计算方式:
从这个示意图中可以看出scaling 部分可以挪到最后去做。因此原本的
可以改为:
可以在最后一步时处理来得到合适的结果,这样能减少非矩阵乘法运算,进一步提高模型运算效率。
对于反向传播阶段,flash-attention2 避免直接存储以及的指数和,作者使用来代替原本反向传播的13行。最终flash-attention2的前向传播和后向传播的伪代码如下:
后向传播:
为降低使用attention预测时KV的缓存大小,flash-attention 2 通常不采用常规的multi-head的方式,而是共享同一个key和Value或者一个组共享一个Key和Value。因此在反向传播时,需要把共享的K和V的梯度进行加和。
作者提到在第一版的flash-attention中,并行化是发生在batch_size和head上,即attention的矩阵计算,作者采用的是一个线程分区去计算一个attention的head,这样整体就需要batch_size*head个线程分区开销。每一个线程开销是都是在流式多任务处理器上跑的。目前A100上大约有100个这样的流式多任务处理器,如果分区使用量能达到80个左右,那么几乎可以利用全部的GPU运算资源。在长文本训练过程中由于显存的限制,bacth_size和head通常比较小,作者加入了序列长度来进行并行计算,这样能进一步提高资源利用率,进而提高运算效率。
前向传播上:
FlashAttention算法有两个循环,KV在外循环,Q在内循环。FlashAttention-2将Q移到了外循环,KV移到了内循环,这样算法就不用相互通信去处理Q的问题,所以外循环可以放在不同的thread block上。这个交换的优化方法是由Phil Tillet在Triton提出并实现的。
FlashAttention forward pass. 如下图 (a) 所示,根据算法原理,每一次内循环,都需要重新写入,这就导致每个warp需要从HBM频繁读写来求得最后的累加和,是非常低效的。
FlashAttention-2 forward pass. FlashAttention-2将移到了外循环,KV移到了内循环,并将Q分为4个warp,所有warp都可以访问K和V。这样做的好处是,原来FlashAttention每次内循环会导致的不断变化(需要通过HBM读写),现在每次内循环处理都是同一个,此时可以存储在SRAM上的,代价远小于读取HBM。
transformer通常以位置编码来表示某个token在句子中的前后关系。每个字的位置编码存在差异,而后再把位置编码值叠加到token的embedding上。【6】发现如果采用原始的ROPE的编码,无论继续预训练(CPT)阶段的文本有多长,模型都无法获得希望的长文本能力。一种合理的猜测是位置编码存在远程衰减从而限制了模型在长距离的表达能力。首先【6】把Q和K想象成值为1的列向量,向量的维度为1*4096,在和ROPE位置编码相加后计算attention-score,发现随着距离的增加,原始的RoPE存在非常明显的衰减,导致softmax前attention score非常小,其均值在0左右,通过之前搜索推荐的经验也可以知道attention分数过大和过小都会导致梯度消失的问题,导致模型是无法优化的。
因此必须增加在softmax前attention-score的数值大小,一种方式可以通过乘以某个常数来增大attention score分数,来使得整个transformer梯度更合理些,另外一种是减小位置编码的衰减,从理论上讲后者会更占据优势,【6】通过增大ROPE的公式中的b来减小旋转角度,进而减小位置编码的衰减,提高softmax前 attention score 的分数,避免梯度消失的现象。
首先来看下几种不同形式下ROPE的表述方式:
其中i是虚数单位,x代表embedding、t代表旋转的角度,j代表维度。之所以是这种形是因为假设编码中位置分别为m和n,可以发现两个不同位置内积,只取决于位置的相对顺序。即m-n,这也意味着ROPE是相对位置编码,与绝对位置无关,非常适合进行外推实验。
这个证明起来相对比较复杂,首先把一个数变成其复数的形式:
假设对位置在m和n的两个token进行内积操作:
可以看到位置编码,只和旋转角度m-n的相对位置相关,但实验中发现距离越长衰减越厉害,导致attention score 分数过低,于是【7】插值方法来抵消掉m-n扩大带来影响,ABF方法也是类似,都在补偿m-n的过大带来衰减。
PI的sin夹角为:
ABF的sin夹角为:
可以看出增大b会减小旋转角度,缓解长距离衰减的问题。
53AI,企业落地应用大模型首选服务商
产品:大模型应用平台+智能体定制开发+落地咨询服务
承诺:先做场景POC验证,看到效果再签署服务协议。零风险落地应用大模型,已交付160+中大型企业
2024-03-30
2024-04-26
2024-05-10
2024-04-12
2024-05-28
2024-05-14
2024-04-25
2024-07-18
2024-04-26
2024-05-06
2024-12-22
2024-12-21
2024-12-21
2024-12-21
2024-12-21
2024-12-20
2024-12-20
2024-12-19