AI知识库

53AI知识库

学习大模型的前沿技术与行业应用场景


「Mamba」能干翻 Transformer 的新架构
发布日期:2024-05-10 13:56:18 浏览次数: 2284


最近 Mamba 架构在 LLMs 圈子引起了广泛关注。不少人之言 Mamba 将会把 Transformer 拉下神坛,将取代 Transformer。本文将详细解读下 Mamba 架构。

题外话

Mamba 论文投稿至 ICLR 2024 被拒了:https://openreview.net/forum?id=AL1fq05o7H

不少学者表示震惊:为什么这么好的论文会被拒!

最近「ICML 审稿质量差」好像也成了学术圈热议的话题!


一、Motivation

在介绍Mamba之前,我们先来回顾下 Transformer 结构以及它存在的问题。Transformer 模块如下图所示:

可以分为两种操作:

  • Communication(Attention):Token 之间的交互(如相关性等)
  • Computation(MLP) :Token 内计算

通过 Communication,模型会计算每个 Token 与其他 Token 的关系,来从而得到该 Token 新的表示。这也是 Transformer 具有如此强大能力的主要原因。

在 Training 过程中,表示 Token 之间的关系矩阵可以通过并行计算一次完成,从而极大地加快了 Transformer 的训练速度。

然而在 Inference 阶段,当生成下一个 Token 时,即使之前已经生成了一些 Token,还要重新计算整个序列的 Attention。

生成长度为  的 Token 序列需要  计算。如果序列长度  不断增加,计算成本将显著增加。这种需要重复计算整个序列 Attention 的操作成了 Transformer 架构的一个主要瓶颈。

让我们来回顾下之前 RNN 是如何解决 Inference 阶段的计算瓶颈的。RNN 在每个时刻需要两个输入,即  时刻的输入  和  时刻的隐藏状态 ,来生成  时刻隐藏状态  和预测输出 。如下图所示:

在生成  时,RNN 只需要考虑之前的隐藏状态  和当前的输入,而无需重新计算所有先前的隐藏状态。这使得 RNN 在 Inference 阶段具有很高的效率。

但是由于 RNN 的循环计算(当前隐状态依赖之前的结果)的存在,使得它在训练阶段无法实现并行化,从而无法实现高效训练。

RNN 的问题正好与 Transformer 相反!它的 Inference 速度非常快,但 Training 速度慢。

能否设计一个模型像 Transformer 那样并行训练,同时在 Inference 阶段跟 RNN 那样具有随序列长度线性增长推理速度?这便是 Mamba 的目标。

二、State Space Model (SSM)

与 Transformer 和 RNN 相同,SSM 用于处理序列数据,如文本、信号等。本节将介绍 SSM 的基础知识以及它们如何处理文本数据。

2.1、State Space

状态空间(State Space)包含了完整描述系统的最低数量的变量,是一种通过数学方式来定义系统可能状态的方法。

如下图所示,目标是在迷宫中从当前位置走到出口“Exit”。“状态空间” 可以是指地图中所有可能位置(状态)。

状态空间表示(State Space Representation)是对当前状态的简化描述,显示当前所在的位置(当前状态)、下一步可以去哪里(未来可能的状态)以及哪些变化会将带到下一个状态(向右或向左)。描述状态的变量(如  和  以及到出口的距离)可以表示为状态向量(State Vector)。

在神经网络中,系统的“状态”通常是隐式表示的,如网络中间层向量。

2.2、State Space Model (SSM)

SSM 是用于描述状态表示,并根据输入预测其下一个状态的模型。简单来说,在  时刻,SSM 作用是:根据输入  学习隐式状态表示 ,并预测输出 。如下图所示, SSM 可以表示为:

状态方程(State Equation)用于描述输入  是如何影响(通过矩阵A)后续状态  的:

输出方程(Output Equation)用于描述了如何根据状态  和当前输入  来预测输出:

其中,矩阵 和  可学习的参数。

整个模型可以表示为如下框架:

SSM 的处理过程可以描述为:

  • 输入信号  与矩阵  相乘,得到一个向量,用于表示输入  对系统状态的影响
  • 状态表示(State Representation)是一个隐向量,包含了系统核心"知识"。该状态与矩阵  相乘,描述内部状态之间的关联,用于描述系统的动态特性。在使用状态之前预测输出之前,要先根据当前状态和输入来更新状态
  • 使用矩阵  将状态转换为输出。矩阵 描述了状态与输出之间的关系,即如何将状态映射到输出空间
  • 最后,采用 skip-connection 通过矩阵  来将输入与输出连接起来

最终 SSM 框架可以简化为:

在实际应用中,通常要处理的是离散信号,这时 SSM 可以表示为(图中skip-connection未展示出来):

在  时刻时:

  • 先根据当前输入  和上一前状态  来更新当前状态 
  • 根据当前状态  来计算输出 

有了对 SSM 的描述之后,下面来看下是如何计算状态的。

2.3、Computing with Recurrence

根据上面 SSM 的离散化表示,我们可以使用类似 RNN 的方式将其表示为:

通过 RNN 的方式我们能很方便的实现对模型的计算。但这种方法具有 RRN 的优点(推理快)的同时,还保留了 RNN 的缺点(训练慢)。

2.4、Computing with Convolution

由于线性递归计算可以转化为卷积运算,从而实现并行加速。因此,可以使用卷积来实现对状态的计算。类似在图像识别任务中使用 kernel 来得到聚合特征,可以使用1维卷积 kernel 来在文本序列中实现卷积操作,如下图所示:

根据 SSM 计算公式,可以推导出对应的 kernel 为:

下面来看下上述 kernel 是如何推倒得到的:

首先对状态公式进行展开:

输出则是根据当前状态进行简单的线性映射:

在  时刻,输出  可以表示为:

因此,通过对上式进行系数提取,便得到 SSM 的 kernel :

通过这种方式,SSM 整个输出  只是输入  的卷积: 。这种卷积表示与循环表示等价,不再要求对输入序列进行顺序处理,而通过一个卷积运算使得整个输出  能并行得到,这极大提升了 SSM 模型的训练效率。然而,由于 kernel 大小的限制,它们在推理上不如 RNN 那样快速。

2.5、Combine Computing

上述提到的两种计算方式有其各自的优、缺点,如下图所示:

一个很合理的想法就是,将这两种方法结合起来,实现相互补充:

  • 在 Training 阶段采用 Convolution,实现高效训练
  • 在 Inference 阶段使用 Recurrence,实现高效推理

这样就得到了训练和推理都高效的 Linear time-invariant State Space models:

三、Mamba (the Selective SSM)

在介绍 Mamba 之前,让我们先来看下当前 SSM 存在的问题。

3.1、Problem with SSM

虽然 SSM 在 training 和 inference 计算性能上有不错的表现,但在准确性上表现却比 Transformer 要差些:

主要是因为 SSM 中的矩阵  是静态的,即所有 Token 对应的  和  都是相同的,如下图所示:

这使得它在一些 Transformer 处理起来很简单的任务上表现出较差的效果,比如下图所示:

  • Selective Copying:复制除白色之外的所有输入
  • Induction Heads:概括序列中已识别的模式

为了提升 SSM 准确性问题,Mamba 提出了 Information Selective 机制对输入序列进行选择。此外,Mamba 采用了Hardware-aware 算法来保证计算效率。

3.2、Information Selective

为了解决 SSM 静态矩阵的问题,Mamba 采用了 Information Selective 机制,从而使不同输入能影响矩阵  和 ,即将它们设计为是输入  的函数:

这里  仍然是静态的是为了让状态自己保持静止,不同输入则可以通过  和  来影响状态的变化。

通过这种机制之后,Mamba 具有了能根据不同输入自主选择将哪些内容保留在状态中、忽略哪些内容的能力。

由于 Information Selective 机制的引入,Mamba 不再是线性时不变 SSM了,这也使得它不能再依赖卷积来实现快速训练了。作者采用的并行扫描和硬件感知内存管理算法来保证 Mamba 的训练效率。

3.3、Parallel Associative Scan

并行关联扫描(Parallel Associative Scan)用于执行前缀和操作(Prefix Sum)。前缀和操作是对数字序列进行累积运算的过程。这个操作是“关联”的,即操作中数字的分组方式不会改变结果。

例如,对于序列[1, 2, 3, 4, 5],前缀和操作将生成新的序列[1, 3, 6, 10, 15],其中每个元素是原始序列中该位置之前所有元素的总和。

并行关联扫描是指将前缀和操作在并行计算环境中进行加速的技术。在传统的串行计算中,执行前缀和操作需要进行多次迭代,并且每次迭代都需要等待上一次迭代的结果。而在并行关联扫描中,可以将序列分成多个部分,并同时对这些部分进行部分前缀和操作。然后,通过将每个部分的结果进行组合和调整,最终得到整个序列的前缀和结果。

在 Mamba 模型的上下文中,通过定义关联算子,获得用于并行关联扫描操作的元素和关联算子。最终通过并行运算解决问题,从而降低计算时间。

3.4、Hardware-aware algorithm

GPU 的一个缺点是其较小但高效的 SRAM 与较大但效率稍低的 DRAM 之间的传输 (IO) 速度有限。在 SRAM 和 DRAM 之间频繁复制会成为计算瓶颈。

为了进一步提升 Mamba 的计算效率,作者利用 Nvidia GPU 中 HBM(High Bandwidth Memory)和 SRAM(Static RAM)速度相关的特性,提出了一种硬件感知算法。

将 Mamba 状态计算按照如下方式安排来进一步提高计算速度:

  • SRAM 中保持隐藏状态和矩阵 
  • HBM 中计算  和 
  • 然后将  和  传输到 SRAM 中,在 SRAM 内计算新的隐藏状态
  • 最后将  和  写回 HBM

3.5、Mamba-block

介绍晚 Mamba 关键技术之后,来整体看下 Decoder 的 Mamba-block 结构,如下图所示:

与 Transformer 架构类似,Mamba 架构也是由堆叠的 Mamba 块组成。

3.6、Empirical Evaluation

在原本 SSM 不擅长的 Selective Copying(a)和 Induction Heads(b) 任务中,Mamba 架构表现出了超强的性能。

与其他知名的开源模型相比,Mamba 在多个常见的下游 zero-shot 任务上取得了最好的效果:

在计算效率上,Mamba 比传统的 SSM 实现方式的 Training 速度快了40倍。在 Inference 阶段,Mamba 比相当大小的 Transform 的吞吐量提高了5倍:

Conclusion

本文从 State Space Model (SSM) 基础知识开始,介绍了 Mamba 架构的由来,以及它要解决的问题和效果。本文希望大家能意识到 SSM 研究方向的潜力。

虽然说 Mamba 要替代 Transform 还为时过早,但它提供了一种完全不同的结构,并在某些任务上取得了与 Transform 相当的性能。就这一点,就足以让我们保持持续关注了。

References

[1] Gu A, Goel K, Ré C. Efficiently modeling long sequences with structured state spaces[J]. arXiv preprint arXiv:2111.00396, 2021.

[2] Gu A, Johnson I, Goel K, et al. Combining recurrent, convolutional, and continuous-time models with linear state space layers[J]. Advances in neural information processing systems, 2021, 34: 572-585.

[3] Kola Ayonrinde. Mamba Explained. The Gradient, 2024

[4] Grootendorst. A Visual Guide to Mamba and State Space Models. Exploring Language Models.2024

[5] Loïck Bourdois. Introduction to State Space Models. huggingface. 2024.

[6] James Chen. Mamba No. 5 (A Little Bit Of...). Sparse Notes. 2024.




53AI,企业落地应用大模型首选服务商

产品:大模型应用平台+智能体定制开发+落地咨询服务

承诺:先做场景POC验证,看到效果再签署服务协议。零风险落地应用大模型,已交付160+中大型企业

联系我们

售前咨询
186 6662 7370
预约演示
185 8882 0121

微信扫码

与创始人交个朋友

回到顶部

 
扫码咨询