微信扫码
与创始人交个朋友
我要投稿
大模型时代能够充分利用 GPU 的显存是一项非常有必要的技能。本文将在仅考虑单卡的情况下为大家讲明白大模型的内存占用机制,相信对大家后续训练、使用大模型都非常有帮助。
这是我曾在面试中被问到的问题,为了巩固相关的知识,打算系统的写一篇文章,帮助自己复习备战秋招的同时,希望也能帮到各位小伙伴。
这篇文章将围绕大模型在单卡训练或推理时的显存占用进行系统学习分析,其中有的知识点可能不会涉及太过深入点到为止(因为我也不会),但尽量保证整个读下来逻辑通畅,通俗易懂(只有小白最懂小白!)。
我们都知道:
由此可以明白,一个含有 1G 参数的模型,如果每一个参数都是 32bit(4byte),那么直接加载模型就会占用 4x1G 的显存。
个人认为只需掌握下图几个常见的数据类型就好,对于更多的精度类型都是可以做到触类旁通发,图源英伟达安培架构白皮书:
可以非常直观地看到,浮点数主要是由符号位(sign)、指数位(exponent)和小数位(mantissa)三部分组成。
符号位都是 1 位(0 表示正,1 表示负),指数位影响浮点数范围,小数位影响精度。
其中 TF32 并不是有 32bit,只有 19bit 不要记错了。BF16 指的是 Brain Float 16,由 Google Brain 团队提出。
我说实话,讲太多不如一个形象的图片或者例子来得直接,下面我们将通过一个例子来深入理解如何通过这三个部分来得到我们最终的数据。
我以 BF16,如今业界用的最广泛的精度类型来举个栗子,下面的数完全是我用克劳德大哥随机画的:
先给出具体计算公式:
然后 step by step 地分析(不是,怎么还对自己使用上 Cot 了)。
符号位 Sign = 1,代表是负数,指数位 Exponent = 17,中间一坨是:,小数位 Mantissa = 3,后面那一坨是:。
最终结果:三个部分乘起来就是最终结果:-8.004646331359449e-34。
注意事项:中间唯一需要注意的地方就是指数位是的全 0 和全 1 状态是特殊情况,不能用公式。
那么就以目前最常见的混合精度训练方法作为参考,来看一看显存都去哪了。
顾名思义,混合精度训练就是将多种不同的精度数据混合在一起训练,《 MIXED PRECISION TRAINING 》这篇论文里将 FP16 和 FP32 混合,优化器用的是 Adam,如下图所示:
按照训练运行的逻辑来讲:
Step5:重复 Step2 到 Step4 训练,直到模型收敛。
我们可以看到训练过程中显存主要被用在四个模块上:
激活值(FP16)
写到这里,我就有 3 个小问题,第一个问题,为什么不全都用 FP16,那不是计算更快、内存更少?
根据我们第一章的知识,我们可以知道 FP16 精度的范围比 FP32 窄了很多,这就会产生数据溢出和舍入误差两个问题,这会导致梯度消失无法训练,所以我们不能全都用 FP16,还需要 FP32 来进行精度保证。
看到这里你也许会想到可以用 BF16 代替,是的,这也是为什么如今很多训练都是 BF16 的原因,至少 BF16 不会产生数据溢出了,业界的实际使用也反馈出比起精度,大模型更在意范围。
第二个问题,为什么我们只对激活值和梯度进行了半精度优化,却新添加了一个 FP32 精度的模型副本,这样子显存不会更大吗?
答案是不会,激活值和 batch_size 以及 seq_length 相关,实际训练的时候激活值对显存的占用会很大,对于激活值的正向优化大于备份模型参数的负向优化,最终的显存是减少的。
第三个问题,我们知道显存和内存一样,有静态和动态之分别,那么上面提到的哪些是静态哪些是动态呢?
应该很多人都能猜到:
动态:激活值、梯度值
写到这里,我们应该对于分析大模型训练时候的显存问题应该不在话下了(除了动态部分),那么我们就来实测一下,正在阅读的小伙伴也可以先自己尝试计算一下,看看是不是真的懂了。
对于 llama3.1 8B 模型,FP32 和 BF16 混合精度训练,用的是 AdamW 优化器,请问模型训练时占用显存大概为多少?
解:
不考虑激活值的情况下,总显存大约占用 (48 + 16 + 64) = 128G
推理的时候,显存几乎只考虑模型参数本身,除此之外就是现在广泛使用的 KV cache 也会占用显存。
KV cache 与之前讲的如何减少显存不一样,KV cache 的目的是减少延迟,也就是为了推理的速度牺牲显存。
具体 KV cache 是什么我就不展开讲了,我贴一张动图就可以非常清晰地明白了。
记住一点,我们推理就是在不断重复地做”生成下一个 token“的任务,生成当前 token 仅仅与当前的 QKV 和之前所有 KV 有关,那么我们就可以去维护这个 KV 并不断更新。
顺便回答一个很多小白经常会问的问题,为什么没有Q Cache呢?
因为生成当前的 token 只依赖当前的 Q,那为什么生成当前的 token 只依赖当前的 Q 呢,因为 Self-Attention 的公式决定的。
S 代表 Softmax 激活函数:
我们可以看到,在序列 t 的位置,也就是第 t 行,只跟 有关系,也就是说,Attention 的计算公式就决定了我们不需要保存每一步的 Q,再深入地说,矩阵乘法的数学特性决定了我们不需要保存每一步的 Q。
如何计算 KV Cache 的显存是我这篇文章想要关心的事情,先给出公式:
前面的四个参数相乘应该很好理解,就是 KV 对应在模型每一层的所有隐藏向量的总和,第一个 2 指的是 KV 两部分,第二个 2 指的是半精度对应的字节数。
举个栗子,对于 llama7B,hiddensize = 4096,seqlength = 2048 , batchsize = 64,layers = 32 计算得到:
可以看到,KV Cache 在大批量长句子的情况下,显存占用率也是很大的。
68G 看着是相对模型本身很大,但这是在 batch 很大的情况下,在单 batch 下,KV Cache 就仅占有 1G 左右的显存了,就仅仅占用模型参数一半的显存。
什么,你觉得 KV Cache 用的显存还是太多了,不错,对于推理落地侧,再怎么严苛要求也是合理的,MQA 和 GQA 就是被用来进一步减少显存的方法,现在的大模型也几乎都用到了这个方法,我们就来讲一讲。
其实方法不难理解,看这张图一目了然,关键词就是“共享多头 KV”,很朴素的删除模型冗余结构的思路。
最左侧就是最基础的 MHA 多头自注意力,中间的 GQA 就是保留几组 KV 头,右侧 MQA 就是只保留 1 组 KV 头,目前用的比较多的是 GQA,降低显存提速的同时也不会太过于影响性能。
上一小节我们知道 MHA 的 KV Cache 占用显存的计算公式是:
有一个小细节,可以重头开始训练 MQA 和 GQA 的模型,也可以像 GQA 论文里面一样基于开源模型,修改模型结构后继续预训练。目前基本上都是从头开始训练的,因为要保持训练和推理的模型结构一致。
能看到这里的人,我想对于 Lora 的原理应该都很了解了,就浅浅提一下,如下图所示,就是在原来的权重矩阵的旁路新建一对低秩的可训练权重,训练的时候只训练旁路,大大降低了训练的权重数量,参数量 d*d 降为 2*d*r。
有了前面的全参情况下训练的显存分析,现在分析起来就比较通顺了,我们一步一步来,还是以 BF16 半精度模型 Adamw 优化器训练为例子,lora 部分的参数精度也是 BF16,并且设 1 字节模型参数对应的显存大小 φ。
首先是模型权重本身的权重,这个肯定是要加载原始模型和 lora 旁路模型的,因为 lora 部分占比小于 2 个数量级,所以显存分析的时候忽略不计,显存占用 2φ。
然后就是优化器部分,优化器也不需要对原模型进行备份了,因为优化器是针对于需要更新参数的模型权重部分进行处理,也就是说优化器只包含 Lora 模型权重相关的内容,考虑到数量级太小,也忽略不计,故优化器部分占用显存 0φ。
其实容易搞错混淆的部分就是梯度的显存了,我看了不少的博客文章,有说原始模型也要参与反向传播,所以是要占用一份梯度显存的,也有的说原始模型都不更新梯度,肯定只需要 Lora 部分的梯度显存,搞得我头很大。
那么究竟正确答案是哪一种呢,这里直接给出答案,不需要计算原始模型部分的梯度,也基本不占用显存。也就是说梯度部分占用显存也可以近似为 0φ。
总的来说,不考虑激活值的情况下,Lora 微调训练的显存占用只有 2φ,一个 7B 的模型 Lora 训练只需要占用显存大约 14G 左右。
验证一下,我们来看 Llama Factory 里给出训练任务的显存预估表格:
可以看到 7B 模型的 Lora 训练的显存消耗与我们估计得也差不多,同时也还可以复习一下全参训练、混合精度训练的显存分析,也是基本符合我们之前的分析的。
上面 Llama Factory 的那张表也是稍微剧透了一下我们接下来要讲的内容,也就是 QLora,继 Lora 之后也是在业界落地非常广泛通用的一种大模型 PEFT 方法。
QLora,也叫做量化 Lora,顾名思义,也就是进一步压缩模型的精度,然后用 Lora 训练,他的核心思路很好理解,但实际上涉及的知识点细节却并不少。
我同样也不会太过深入地去介绍这个中细节,我主要是想按照显存占用的思路去分析 Qlora,理解思路永远比死的知识点更加重要。
Qlora 来自于《 QLORA: Efficient Finetuning of Quantized LLMs 》这篇论文,实际上这篇论文的核心在于提出了一种新的量化方法,重点在于量化而不是 Lora。
很多不了解的人看到量化 lora 这个名字就以为是对 Lora 部分的参数进行量化,因为他们认为毕竟只有 Lora 部分的参数参与了训练。
但理解了上面一节的小伙伴就明白实际并不是这样,原始模型的本身参数虽然不更新参数,但是仍然需要前向和反向传播,QLora 优化的正是 Lora 里显存占大头的模型参数本身。
那么 Qlora 就是把原始模型参数从 16bit 压缩到 4bit,然后更新这个 4bit 参数吗?
非也非也,这里需要区分两个概念,一个是计算参数,一个是存储参数,计算参数就是在前向、反向传播参与实际计算的参数,存储参数就是不参与计算一开始加载的原始参数。
QLora 的方法就是,加载并且量化 16bit 的模型原始参数为 4bit 作为存储参数,但是在具体需要计算的时候,将该部分的 4bit 参数反量化为 16bit 作为计算参数。
也就是说,QLora 实际上我们训练计算里用到的所有数据的精度都是和 Lora 一样的,只是加载的模型是 4bit,会进行一个反量化到 16bit 的方法,用完即释放。
前面说到的都是模型原始参数本身,不包括 lora 部分的参数,Lora 部分的参数不需要量化,一直都是 16bit。
看到这里机智的你应该也想到了,这比 Lora 多了一个量化反量化的操作,那训练时间是不是会更长,没错一般来讲 Qlora 训练会比 Lora 多用 30% 左右的时间。
基本的思路讲完了,那么其中包含了哪些具体的实现细节呢?
Qlora 主要包括三个创新点,这里我只简单提及,应付面试足够的程度:
优化器分页:为了防止 OOM,可以在 GPU 显存紧张的时候利用 CPU 内存进行加载参数。
想必已经理解 QLora 运行思路的小伙伴,应该可以很轻松的分析出 Qlora 占用显存的部分了吧,这就是理清楚思路的好处。
没错,Qlora 占用的显存主要就是 4Bit 量化后的模型本身也就是 0.5φ,这里没有考虑少量的 Lora 部分的参数和量化计算中可能产生的显存。可以回过头去看看刚才的表格,也是基本符合预期的。
来源:https://zhuanlan.zhihu.com/p/713256008
53AI,企业落地大模型首选服务商
产品:场景落地咨询+大模型应用平台+行业解决方案
承诺:免费场景POC验证,效果验证后签署服务协议。零风险落地应用大模型,已交付160+中大型企业
2024-09-18
2024-07-11
2024-07-11
2024-07-26
2024-07-09
2024-06-11
2024-10-20
2024-07-20
2024-07-23
2024-07-12