AI知识库

53AI知识库

学习大模型的前沿技术与行业应用场景


大模型:训练时GPU显存不足怎么办
发布日期:2024-05-03 13:10:08 浏览次数: 2168


前言

大模型时代对显存的要求越来越高,之前在BERT刚诞生时候写过一篇:GPU 显存不足怎么办?,新的这篇文章主要是重构之前的文章,来聊聊大模型时代显存不足时怎么办,没有看过的朋友直接看这篇即可。

 

训练时显存占用分析

训练模型时所占用的显存主要分为以下部分:模型权重参数优化器状态,梯度激活值。假定模型本身的大小为 A,且以 fp32 为精度计算。


模型权重参数

在模型显存为 A 的情况下,所占用的显存为:
  • fp32 精度显存占用:4A
  • 混合精度下显存占用(bf16/fp16):2A

优化器状态与梯度

  • SGD为例,其计算公式为:
        
    我们看到在 SGD 中,那么此时的显存占用只有梯度:
  • Momentum-SGD 为例,其计算公式为:
    我们看到在Momentum-SGD 中,不仅仅有梯度 ,还有动量
  • Adam 为例,其计算公式为:

    

   我们看到在 Adam 中,需要保存的包括:当前梯度,梯度加权平均,梯度平方的加权平均

因此,假定模型大小为 A,训练中采用 FP32 精度进行优化,那么此时优化器状态和梯度占用的显存分别为:
  • SGD:优化器状态:0, 梯度:4A
  • Momentum-SGD:优化器状态:4A,梯度:4A
  • Adam优化器状态:8A,梯度:4A
而在实际的训练中往往采用混合精度训练,而在混合精度训练下的显存又有所区别。


激活值

激活值的显存占用与 token长度per_gpu_batch_sizehidden_size 以及 transformer层数 正相关,并且占用显存也非常大,此处就不细写了,主要是技术很复杂,我也没算明白,哈哈哈哈。

 

训练时显存不足怎么办?

下面列出一些常见的节省显存的操作,优先级从高到低排列
  • 去掉compute_metrics:有些代码会在输出层后计算rouge分等,这个会输出一个batch_size*vocab_size*seq_len 的一个大向量,非常占显存。
  • 采用bf16/fp16进行混合精度训练:现在大模型基本上都采用 bf16 来进行训练,但是如v100这些机器不支持,可以采用fp16进行训练。显存占用能够降低一倍。
  • Flash attention:不仅能够降低显存,更能提高训练速度。
  • 降低你的batch size:如上文所述,batch size 与模型每层的激活状态所占显存呈正相关,降低batch size 能够很大程度上降低这部分显存占用。
  • 采用梯度累积:global batch size = batch size * 梯度累积,如果降低 batch size 后想保持你的 global batch size 不变,可以适当提高梯度累积值。
  • 选择合适的上下文长度:如上文所述,上下文长度与激活状态所占显存呈正相关,因此可以通过适当降低上下文长度来降低显存占用。
  • DeepSpeed Zero:显存占用从高到低为Zero 1 > Zero 2 > Zero 2 + offload > zero 3 > zero 3 + offload,推荐最多试到 Zero2 + offload
  • 选择更小的基座模型:在满足需求的情况下,尽量选择更小的基座模型。
 
几个慎重选择的操作:
  • Lora:能跑全参就别跑 Lora 或 Qlora,一方面是麻烦,另一方面的确是效果差点。
  • Qlora:Qlora 的速度比lora慢,但所需显存更少,实在没资源可以试试。
  • Megatron-LM:可以采用流水线并行张量并行,使用比较麻烦,适合喜欢折腾的同学。
  • Pai-Megatron-LM:Megatron-LM 的衍生,支持 Qwen 的sft和pt,坑比较多,爱折腾可以试试。
  • 激活检查点:不推荐,非常耗时。在反向传播时重新计算深度神经网络的中间值。用时间(重新计算这些值两次的时间成本)来换空间(提前存储这些值的内存成本)

 

最后

ok,本文到此就结束了,本文主要是对之前文章进行了细化,并补充了大模型时代下的几种显存不足时的方法。
大模型时代来了,乞丐玩家是不是更多了啊,同学。
微信公众号不支持公式,太费劲了。

欢迎大家关注我的微信公众号:老宋聊AI。

 

参考

【1】https://zhuanlan.zhihu.com/p/31558973




53AI,企业落地应用大模型首选服务商

产品:大模型应用平台+智能体定制开发+落地咨询服务

承诺:先做场景POC验证,看到效果再签署服务协议。零风险落地应用大模型,已交付160+中大型企业

联系我们

售前咨询
186 6662 7370
预约演示
185 8882 0121

微信扫码

与创始人交个朋友

回到顶部

 
扫码咨询