AI知识库

53AI知识库

学习大模型的前沿技术与行业应用场景


全网首篇从tensorRT-LLM MoE CUDA kernel角度理解Mixtral-8x7b的推理加速及展望
发布日期:2024-05-19 05:57:10 浏览次数: 2353 来源:AI不止算法


最近,LLM MoE非常火,Mixtral-8x7b是第一个开源的MoE,先说明一下,标题里面的“全网首篇”不能随便加,我一直在调研相关资料,发现确实没有从tensorRT-LLM角度来讲的,所以才加上了哈。

网上已经有了模型结构和官方python源码和huggingface源码对Mixtral-8x7b作解析的优秀文章,解析得都非常好,然而,他们有两个不足点在于:

  • 是没有区分开训练和推理来解析,这导致很多朋友不知道哪部分是训练时候做的,哪部分是推理时候做的,看了后非常混乱,对MoE还是似懂非懂

  • 是多数都是停留在Pytorch python API的角度去解析,不是说不好,只是还比较偏上层,不利于读者深度理解大模型推理引擎层面该怎么样去实现并加速Mixtral-8x7b为代表的MoE模型

本文出于如下原因先讲讲MoE推理相关的内容:

  • 训练比推理不仅内容方面更杂,而且理解起来门槛也更高,推理很适合对MoE建立起80%的认识

  • 本人做推理多一些

已经对MoE比较熟悉的朋友,可以直接划到最后看tensorRT-LLM的实现。

一句话讲明MoE的概念

如果大家对传统机器学习算法比较熟悉,看过李沐的统计学习方法或者吴恩达的机器学习,那么应该听过bagging、boosting等集成学习方法,MoE其实也是集成学习,相较于深度神经网络(Deep Neural Network), MoE更像是宽度神经网络,如下图所示,对于MoE的结果是多个expert的输出进行加权组合得到的,router又叫gating网络,包括一个linear和softmax,起到路由的作用,分发给不同expert权重。所以AAAI22也有一篇paper叫做“Go wider instead of deeper

# moe的pytorch代码import torchimport torch.nn as nn
class Expert(nn.Module): def __init__(self, input_dim):        super(Expert, self).__init__() self.fc = nn.Linear(input_dim, 1) def forward(self, x): return self.fc(x) class MoE(nn.Module): def __init__(self, input_dim, num_experts): super(MoE, self).__init__()        self.experts = nn.ModuleList([Expert(input_dim) for _ in range(num_experts)]) # gating的组成 self.gating = nn.Sequential( nn.Linear(input_dim, num_experts), nn.Softmax(dim=1) )
def forward(self, x): # 各个expert做forward前向推理 expert_outputs = [expert(x) for expert in self.experts] expert_outputs = torch.stack(expert_outputs, dim=1)        # 加权组合各个expert的输出        gating_weights = self.gating(x) final_output = torch.sum(expert_outputs * gating_weights.unsqueeze(2), dim=1) return final_output

Mixtral 8x7B采用的Sparse MoE

随着LLM的迅速发展,参数量急剧膨胀,LLM模型压缩的需求也急剧攀升,其实量化在这之前没有那么火,LLM带来的压缩需求把量化推向了一个新高度。稀疏呢,其实之前也不火,虽然nvidia在A100里面支持sparse tensorcore,但是后来稀疏也并没有进入大众视野,nvidia自己也没有太多推广自己在稀疏方面的工作。这一次随着MoE的稀疏化,一定程度缓解了LLM参数膨胀的问题,我感觉稀疏这个技术会随着LLM的发展也会进入非常多人的视野,得到市场的认可。

observation:当MoE有1000甚至上万个专家,其gating产生的权重将存在非常多近near-zero,此时与稀疏对上了

SparseMoE:在gating稀疏的情况下,只需取topK的gating值对应的expert来计算,最终再把各个expert结果给reduce起来,此时只需要计算少数的expert,这也是为什么在Mixtral-8x7B论文中声称大大减少了计算量,因为实际只需2个7b模型即13b参数参与计算,然后再加上精密的训练技巧,就可比肩LLama2-70b和GPT3.5的任务表现.

Mixtral 8x7SparseMoE pytorch实现

部分实现已经upstream到了huggingface(https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral)

Mixtral 8x7B实际上是Mistral 7B的MoE版本,Mistral 7B与LLama整体来说模型结构类似,有几点不同,具体可以参考https://zhuanlan.zhihu.com/p/684922663的解析。

我们在Mixtral 8x7B的huggingface pytorch实现中,对于推理,只需要关注这么几处变化

# 新增了MixtralBLockSparseTop2MLP替换MLP类,其实和LlamaMLP是一样的# 这也是expert layerclass MixtralBLockSparseTop2MLP(nn.Module):    def __init__(self, config: MixtralConfig):        super().__init__()        self.ffn_dim = config.intermediate_size        self.hidden_dim = config.hidden_size
self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
self.act_fn = nn.SiLU()
def forward(self, hidden_states): y = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) y = self.w2(y) return y # 新增sparse MoE layer class MixtralSparseMoeBlock(nn.Module): def __init__(self, config): super().__init__() self.hidden_dim = config.hidden_size self.ffn_dim = config.intermediate_size self.num_experts = config.num_local_experts self.top_k = config.num_experts_per_tok
# gating self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
        # 多个MixtralBLockSparseTop2MLP层组成混合专家 self.experts = nn.ModuleList([MixtralBLockSparseTop2MLP(config) \ for _ in range(self.num_experts)])

用图片来表示sparse MoE在Mixtral 8x7B的位置如下:

说白了,MoE FFN替换了原来的FFN,MoE FFN由 gate + 8* experts组成,上文讲到MoE的gating由linear和softmax组成,但在sparse MoE中,还要多一个topK,Mixtral 8x7B的K=2,experts无异,依然就是原来的LlamaMLP或者LlamaFFN,只是参数量或者weight shape和Llama2 70b不一样,要小很多罢了

那么,sparse MoE layer,也就是MixtralSparseMoEBlock,它的forward函数是咋实现的呢?这个关系到我们要写哪些CUDA kernel。

在放出forward函数之前,我们要注意,MoE layer是独立应用在各个token的,所以我们需要收集各个token需要哪些expert来做MoE,换句话说,我们需要收集各个expert需要处理哪些token

# 1.计算gatingtokens = 6x = torch.randn(1, tokens, 128) # 6个tokenhidden_states = xbatch_size, sequence_length, hidden_dim = hidden_states.shapehidden_states = hidden_states.view(-1, hidden_dim)
# 各个expert的权重router_logits = self.gate(hidden_states)# 计算 TopK logits 和 TopK expert idrouting_weights = F.softmax(router_logits, dim=1, dtype=torch.float)routing_weights, selected_experts = torch.topk(routing_weights, \                                               experts.top_k, dim=-1)# 归一化routing_weights /= routing_weights.sum(dim=-1, keepdim=True)routing_weights = routing_weights.to(hidden_states.dtype)# onehot encode the selected experts to create an expert mask# this will be used to easily index which expert is going to be sollicitatedexpert_mask = torch.nn.functional.one_hot(selected_experts, \ num_classes=experts.num_experts).permute(2, 1, 0)# 最终结果的变量定义final_hidden_states = torch.zeros( (batch_size * sequence_length, hidden_dim), \ dtype=hidden_states.dtype, device=hidden_states.device)
# 每个expert收集需要计算的token idfor exprt_idx in range(experts.num_experts): expert_layer = experts.experts[expert_idx] print(expert_mask[expert_idx])    # 根据expert mask去到各个expert要计算的token id    idx, top_x = torch.where(expert_mask[expert_idx]) top_x_list = top_x.tolist() idx_list = idx.tolist() # 取对应token id的hidden states current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
    # 将对应token id的hidden states送进expert layer做MLP,然后乘上gating算出来的对应weight current_hidden_states = expert_layer(current_state) \ * routing_weights[top_x_list, idx_list, None]
    # reduce_sum各个expert对某个token的计算结果    final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))

画一张简单易于理解的图来表示,如下图。对应于上面代码,每个expert都去找自己要处理的token,比如expert1找到token0和token1,expert2找到token2....expert5找到token1,然后各个expert与对应的token hidden states做FFN,得到该token上的中间hidden states,这是什么意思?比如token0与expert1做完FFN之后,需要等待token0与expert3做完FFN之后的值,做一个加权reduce sum(代码里面是index_add)才是token0做完FFN的最终hidden states。

Mixtral 8x7SparseMoE tensorRT-LLM CUDA实现

理清了以上pytorch实现之后,问题来了:上述代码是for循环执行各个expert与对应token的FFN,最后各个token再累加自身的中间hidden states,本质上是一个串行操作。那么如何用CUDA并行实现呢?

tensorRT-LLM对于这块的代码位于https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu,MoElayer的run函数位于923行,流程与普通kernel无异:

  • 分配buffer或者workspace

  • 预防性check

  • launch一系列kernel

    • gating(linear+softmax+topK)

    • expert


难点就在于expert这里如何并行起来,本文单单针对计算逻辑来讲,这里简单来说tensorRT-LLM采用了一种非常straightforward的方法,既然每个expert都是对某几个token做FFN,针对里面的GEMM,那我让各个expert做的GEMM都merge成一个batch,然后做batch GEMM不就好了,batch size的大小等于expert的个数,只需要按照expert的顺序复制以及重排各个token row(每个token表示GEMM左矩阵的一行,后文会解释)并且再记录各个token row对应的原始位置,最后experts做batch GEMM完成之后根据此记录恢复原始的顺序,再加权reduce sum或者加权index add即可得到每个token的hidden states。

针对上文的复制操作、重排操作,主要通过988行和996行的两个CUDA kernel完成,恢复作由1028行这个CUDA kernel完成

画个简单的图来表示一下这个batch GEMM:

可以看到token id按照expert的顺序重排了

GEMM全部都调用cutlass kernel来完成。针对这块相关论文有以下论文推荐部分的第4篇:megablocks。

megablocks的动机在于发现了包括但不限于TensorRT-LLM等框架的MoE  kernel实现的一些限制,涉及到负载不均衡(load-imbalance)引起的padding问题,例如,以上tensorRT-LLM可以merge成一个batch GEMM的前提,就是要做padding,如下图,使得每一个expert都匹配一个相同shape的weight矩阵,以及分配给每个expert的token数量即左矩阵的行数要一样,所以以上那个图,expert1234所负责的token数量都是需要padding或者复制到一个固定值的,这里为了方便理解没有画。

这个后面再来分享。

最后,个人认为大模型LLM训练和推理可能会朝着稀疏的方向逐渐发展,这中间又会牵扯到量化、sparse kernel甚至sparse AI accelerator,总体来说,训练面临的挑战性会更大,这一点随着mixtral-8x7b的开源应该会造成大量高质量paper的产出。

论文推荐

关于MoE paper也非常多,https://zhuanlan.zhihu.com/p/542465517有总结,我这里给大家过滤一下,推荐一些可精读的MoE相关paper

  1. Mixtral of Experts---by mistral.ai

  2. Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity, JMLR'22:针对训练loss func的改进

  3. OUTRAGEOUSLY LARGE NEURAL NETWORKS: THE SPARSELY-GATED MIXTURE-OF-EXPERTS LAYER,ICLR'17: 提出sparse MoE的文章,更加贴近LLM训练和推理特性的MoE,目前主流的MoE结构,也是Mixtral-8x7b的MoE架构

  4. Megablocks: Efficient sparse training with mixture-of-experts ,2022:稀疏大矩阵乘法CUDA实现

  5. DeepSpeed-MoE: Advancing Mixture-of-Experts Inference and Training to Power Next-Generation AI Scale:针对MoE的训练和推理的端到端解决方案


53AI,企业落地应用大模型首选服务商

产品:大模型应用平台+智能体定制开发+落地咨询服务

承诺:先做场景POC验证,看到效果再签署服务协议。零风险落地应用大模型,已交付160+中大型企业

联系我们

售前咨询
186 6662 7370
预约演示
185 8882 0121

微信扫码

与创始人交个朋友

回到顶部

 
扫码咨询