微信扫码
与创始人交个朋友
我要投稿
作者:你的真实姓名
知乎:https://www.zhihu.com/question/650979052/answer/3501160453
最近看到知乎一个回答,把千卡训练的难度吹上天了。但其实真正用过千卡就会发现也就那么几个点。于是想写一篇文章简单讲讲。
本文将包括3个部分:首先我们将讨论千卡训练的难题,以及应该在什么时候使用千卡训练;接着,我们将讨论如何在一千张卡上开始训练,如何让他达到近乎线性的性能提升;最后我们将展开讨论一些千卡训练当中仍然悬而未决(至少对于开源社区来说)的问题。
千卡训练和八卡训练的区别是—显卡多了一百多倍。
这意味着什么呢?
这俩问题都很好理解。
时间上,PyTorch内部支持NCCL/Gloo/MPI三个通信后端(请务必使用NCCL。其中AllReduce操作会会根据具体硬件配置走Ring AllReduce和Tree AllReduce。Ring的时间复杂度是,Tree的时间复杂度是 。就算是理论上128节点也比单节点慢至少七倍,实践当中跨节点通讯要远比单节点慢得多。
故障上,一个节点出问题的概率是p,128个节点就是1-(1-p)^128。也就是说如果一个操作在一个训练当中的出错概率是1%,那么在128节点当中的出错概率就是72.37%。
此外,随着规模的增大,许多问题都会变得难以忍受。比如数据增强要花0.1s,一亿条数据就是278个小时(当然这只是胡拆的一个数字,实际有各种机制所以不会有这么大影响。
因此,钱多烧手并不是使用千卡训练的理由。闲得蛋疼可能是,但你得多蛋疼才能想出这么折磨自己的idea?
千卡训练解决的问题是大模型&大数据问题。如果你的训练时间没有超过8192GPU日,那么你绝对不需要一千张显卡。
看到这里,绝大多数人已经可以关掉这篇文章了。除非你的模型和数据都以B(十亿)来作为计量单位。当然如果你正在厕所里手机没电想看点儿东西解闷儿的话(虽然我很怀疑是否会有人把他打出来……那么可以继续往下看
这件事情其实是一个case by case的事情。因为通信、计算速度啥的受硬件影响更多。而每一个集群的硬件拓扑都是不一样的。同样是A100集群,我全DGX节点,每一张A100都是SXM接口并配一块儿专属的IB网卡。你一个小破普惠服务器插8张PCI-E A100,IB卡一个节点只给一张。那咱俩遇到的问题就完全不是一个问题。
因此,要讨论如何提高训练效率、减少训练耗时,我们首先要了解训练耗时在哪里。那么,一个训练步的耗时在哪里呢?需要谨记,没有profile的优化是没有意义的。
你可能会说,forward backward sync。很好,这说明你了解PyTorch的基本流程。不过现实当中要复杂得多。
当然这是可以无限细分下去的,但一般这些就够了。需要注意的是,除了4-7的耗时是真耗时,其他都需要通过异步操作来盖掉。这也是我们的优化目标。
异步执行在PyTorch的dataloader、CUDA和分布式当中都存在。前者可以通过设置num_workers和prefetch_count为0来关闭,后两者可以通过cuda.synchornize和dist.barrier来执行手动同步。在profile时,我们需要首先需要测整个step的时长。然后再在每次测量前执行手动同步来计算每个部分的时长。如果前者的总耗时等于后者4-7的耗时之和,那么通常不需要执行任何操作。但这种情况在千卡操作中几乎不可能发生。
第6步通信往往需要耗费大量时间。因此,我们还需要进一步优化通信。
以下内容是对《PyTorch Distributed: Experiences on Accelerating Data Parallel Training》论文的概括,有感兴趣的同学建议通读并背诵全文。
Paper: https://arxiv.org/abs/2006.15704
在PyTorch当中,梯度的通信和反向传播是交叠进行的。也就是说,每完成一层的梯度计算,都会立即触发当前层的同步。实现起来也很简单,每个进程在完成自己第k层的梯度计算后都会触发一个钩子来给计数器+1s。当计数器达到进程数是开火进行梯度通信。有很多同学在计算梯度过程中遇到过RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one.错误,这就是因为有的模块没有参与计算loss,导致梯度同步卡住了。需要注意,当find_unused_parameters=True时,PyTorch分布式使用nn.Module.__init__当中定义sub-module的反向顺序来作为梯度桶的构建顺序。因此,确保模块定义和调用的顺序一致对于高效训练来说很重要。
尽管理论上来说,同步发生的越及时,重合度越高,性能越好。但实际上每次发起通信都是有开销的。因此,现实当中梯度同步并不是越多越好越快越好。为此,PyTorch引入了梯度合桶机制,通过把多个Tensor装在一个桶里再通信桶来减少通信次数从而减少总耗时。合桶的Buffer Size等等参数往往需要针对硬件和模型来调整从而取得最好的通信效果。PyTorch的默认参数是从0.x时代祖传下来的,这一参数通常都需要调节。
当你做完所有操作之后,惊喜的发现TMD怎么同步时间还是单节点的好几倍。这其实是正常情况……实际上超过256卡的训练想要把通信盖掉就是一件不可能的事情。你说老师我看FB论文说他们256卡就是线性提升啊…那这里不得不提的一个策略就是梯度累加了。梯度累加会执行k次forward+backward之后再执行优化器步进。这有很多好处,首先对于大模型batch size通常不能开多大,梯度累加可以提升等效batch size。其次累加期间的backward不需要通信梯度,加快了训练速度。
Python是一种很慢的代码。当然你说JIT trace+torch.compile有提升我也不反对,但对于最高效率来说,只有必须要存在的代码和不存在的代码两种。
抱抱脸的Transformers就是一个反例。两个sub-Module就能写完的TransformerLayer他们硬是能写出来一堆…偏偏他们还信奉Single Model File Policy……我寻思你这完全不考虑继承的封这么多层是要搞鸡毛啊?正例反而是PyTorch……(笑死,我竟然会夸脸书代码写得好。具体来说就是nn.functional当中的各种实现。你会发现他们第一行往往是handle_torch_func。熟悉Python装饰器的小伙汁通常要问了,为啥这里不用个装饰器统一一下?因为装饰器会引入额外的函数调用,额外的函数调用就是额外的开销。
因此,如果你想确保最高的效率,写一个简单的训练代码和模型代码非常重要。毕竟,1%的效率提升,节省的可能是数百个GPU日。
这一段当中中咱们只讨论你能控制的问题。
故障率高的问题其实很好解决。在训练当中,大部分异常都是非致命异常,捉住他们就好了。https://danling.org/utils/decorators/#danling.utils.decorators.catch 是我之前写的一个装饰器,它的作用就是catch异常,然后调回调函数(默认当然就是把错误打印到log里)。所有你需要做的只是使用它来装饰非fatal的操作。
在实际应用当中,我们遇到的最常见的问题是存ckpt写满了磁盘(不准笑,从商汤到深势再到上海AI Lab,这个问题在哪儿都有出现。咱也不知道为啥肯买那么多显卡但不肯多插点儿硬盘,咱也不敢问)。catch住所有保存操作,如果你有闲心可以在回调里删一下之前的ckpt。没嫌心的话…大不了重训一次嘛(逃。第二常见的问题,你猜对了……存log写满了硬盘……所以所有logging操作也都是要catch的。这就是为啥我都用tmux然后开很长的缓存窗口,总是能抢救一些log出来的。
咳咳,说点儿正经的。任何联网操作都是需要catch的,常见的联网操作主要包括从ceph读取数据和…写log到远程(逃。其他就没啥了吧,我见过有大哥尝试恢复OOM的,但效果似乎不是很好,至少我自己没用过。简单来说,唯一不应捕捉的错误是集群炸了。
那有的大兄弟就说了,集群没爆炸,但是有两张卡突然掉了咋办。这个咱第三部分再讨论。
有用过[丹灵]http://danling.org的同学可能比较熟悉。丹灵其他地方都很轻量,唯独实验管理这里写的很复杂。现代丹灵会将创建一个三个级别的实验目录,project/experiment-run/timestamp。其中project是用户给出的,experiment和run分别是通过代码版本和配置计算出来的,timestamp就是运行开始的时间。也就是说,如果代码和配置是完全一样的,丹灵就会认为这是同一个运行。在设置中打开auto_resum就会自动找最新的一个检查点(这就是为啥最后一级要用时间戳)来加载。其实微软用的amlt更好用,他甚至还会创建一个代码的diff文件夹来帮助你回忆当初代码修改了些啥。
模型训着训着发散了几乎是每个训大模型的人都会遇到的问题。输出和loss只要有nan果断丢掉。梯度先clip by value再clip by norm都是常规操作。哦对了,还有初始化……关于大模型收敛性的论文有一堆,此处不再赘述。
实际上当你的训练超过2048个GPU日时,在整个训练过程当中发生单个GPU甚至单个节点下线是再正常不过的事情了。
PyTorch在1.10就引入了torchelastic弹性训练机制,用过的都骂娘。等下,让我先骂一遍,呸。ok咱们继续吧。
我印象当中在微软的最后一轮面试当中被问到了这个问题:如何设计一个弹性分布式系统。
我的回答很教科书。每k分钟,系统会做一次AllReduce来统计存活进程数,然后选举出一个主进程。主进程会计算好每个进程的rank和local rank进行broadcast。所有进程每次forward开始时向主进程发送一个心跳包来汇报状态。主进程会根据心跳包来确定这一个step参与同步的机器有多少
。但很可惜,2024年了。还是没人去写。
我一直认为梯度同步不应该以GPU/进程为单位。而应该分为大同步(节点间同步)和小同步(节点内同步)。小同步可以更高频的进行,大同步则可以更慢的执行。这样不仅能提高实际的梯度同步频率,降低同步总耗时,并且还能天然的去结合小batch和大batch训练的优点—节点内小batch关注个体,节点间大batch关注整体。
53AI,企业落地应用大模型首选服务商
产品:大模型应用平台+智能体定制开发+落地咨询服务
承诺:先做场景POC验证,看到效果再签署服务协议。零风险落地应用大模型,已交付160+中大型企业
2024-07-11
2024-07-11
2024-07-09
2024-09-18
2024-06-11
2024-07-23
2024-07-20
2024-07-12
2024-07-26
2024-07-23
2024-11-18
2024-11-16
2024-11-16
2024-10-31
2024-10-31
2024-10-27
2024-10-26
2024-10-25