微信扫码
与创始人交个朋友
我要投稿
最近 Mamba 架构在 LLMs 圈子引起了广泛关注。不少人之言 Mamba 将会把 Transformer 拉下神坛,将取代 Transformer。本文将详细解读下 Mamba 架构。
题外话
Mamba 论文投稿至 ICLR 2024 被拒了:https://openreview.net/forum?id=AL1fq05o7H
不少学者表示震惊:为什么这么好的论文会被拒!
最近「ICML 审稿质量差」好像也成了学术圈热议的话题!
在介绍Mamba之前,我们先来回顾下 Transformer 结构以及它存在的问题。Transformer 模块如下图所示:
可以分为两种操作:
通过 Communication,模型会计算每个 Token 与其他 Token 的关系,来从而得到该 Token 新的表示。这也是 Transformer 具有如此强大能力的主要原因。
在 Training 过程中,表示 Token 之间的关系矩阵可以通过并行计算一次完成,从而极大地加快了 Transformer 的训练速度。
然而在 Inference 阶段,当生成下一个 Token 时,即使之前已经生成了一些 Token,还要重新计算整个序列的 Attention。
生成长度为 的 Token 序列需要 计算。如果序列长度 不断增加,计算成本将显著增加。这种需要重复计算整个序列 Attention 的操作成了 Transformer 架构的一个主要瓶颈。
让我们来回顾下之前 RNN 是如何解决 Inference 阶段的计算瓶颈的。RNN 在每个时刻需要两个输入,即 时刻的输入 和 时刻的隐藏状态 ,来生成 时刻隐藏状态 和预测输出 。如下图所示:
在生成 时,RNN 只需要考虑之前的隐藏状态 和当前的输入,而无需重新计算所有先前的隐藏状态。这使得 RNN 在 Inference 阶段具有很高的效率。
但是由于 RNN 的循环计算(当前隐状态依赖之前的结果)的存在,使得它在训练阶段无法实现并行化,从而无法实现高效训练。
RNN 的问题正好与 Transformer 相反!它的 Inference 速度非常快,但 Training 速度慢。
能否设计一个模型像 Transformer 那样并行训练,同时在 Inference 阶段跟 RNN 那样具有随序列长度线性增长推理速度?这便是 Mamba 的目标。
与 Transformer 和 RNN 相同,SSM 用于处理序列数据,如文本、信号等。本节将介绍 SSM 的基础知识以及它们如何处理文本数据。
状态空间(State Space)包含了完整描述系统的最低数量的变量,是一种通过数学方式来定义系统可能状态的方法。
如下图所示,目标是在迷宫中从当前位置走到出口“Exit”。“状态空间” 可以是指地图中所有可能位置(状态)。
状态空间表示(State Space Representation)是对当前状态的简化描述,显示当前所在的位置(当前状态)、下一步可以去哪里(未来可能的状态)以及哪些变化会将带到下一个状态(向右或向左)。描述状态的变量(如 和 以及到出口的距离)可以表示为状态向量(State Vector)。
在神经网络中,系统的“状态”通常是隐式表示的,如网络中间层向量。
SSM 是用于描述状态表示,并根据输入预测其下一个状态的模型。简单来说,在 时刻,SSM 作用是:根据输入 学习隐式状态表示 ,并预测输出 。如下图所示, SSM 可以表示为:
状态方程(State Equation)用于描述输入 是如何影响(通过矩阵A)后续状态 的:
输出方程(Output Equation)用于描述了如何根据状态 和当前输入 来预测输出:
其中,矩阵、、 和 可学习的参数。
整个模型可以表示为如下框架:
SSM 的处理过程可以描述为:
最终 SSM 框架可以简化为:
在实际应用中,通常要处理的是离散信号,这时 SSM 可以表示为(图中skip-connection未展示出来):
在 时刻时:
有了对 SSM 的描述之后,下面来看下是如何计算状态的。
根据上面 SSM 的离散化表示,我们可以使用类似 RNN 的方式将其表示为:
通过 RNN 的方式我们能很方便的实现对模型的计算。但这种方法具有 RRN 的优点(推理快)的同时,还保留了 RNN 的缺点(训练慢)。
由于线性递归计算可以转化为卷积运算,从而实现并行加速。因此,可以使用卷积来实现对状态的计算。类似在图像识别任务中使用 kernel 来得到聚合特征,可以使用1维卷积 kernel 来在文本序列中实现卷积操作,如下图所示:
根据 SSM 计算公式,可以推导出对应的 kernel 为:
下面来看下上述 kernel 是如何推倒得到的:
首先对状态公式进行展开:
输出则是根据当前状态进行简单的线性映射:
在 时刻,输出 可以表示为:
因此,通过对上式进行系数提取,便得到 SSM 的 kernel :
通过这种方式,SSM 整个输出 只是输入 的卷积: 。这种卷积表示与循环表示等价,不再要求对输入序列进行顺序处理,而通过一个卷积运算使得整个输出 能并行得到,这极大提升了 SSM 模型的训练效率。然而,由于 kernel 大小的限制,它们在推理上不如 RNN 那样快速。
上述提到的两种计算方式有其各自的优、缺点,如下图所示:
一个很合理的想法就是,将这两种方法结合起来,实现相互补充:
这样就得到了训练和推理都高效的 Linear time-invariant State Space models:
在介绍 Mamba 之前,让我们先来看下当前 SSM 存在的问题。
虽然 SSM 在 training 和 inference 计算性能上有不错的表现,但在准确性上表现却比 Transformer 要差些:
主要是因为 SSM 中的矩阵 、、 是静态的,即所有 Token 对应的 、 和 都是相同的,如下图所示:
这使得它在一些 Transformer 处理起来很简单的任务上表现出较差的效果,比如下图所示:
为了提升 SSM 准确性问题,Mamba 提出了 Information Selective
机制对输入序列进行选择。此外,Mamba 采用了Hardware-aware
算法来保证计算效率。
为了解决 SSM 静态矩阵的问题,Mamba 采用了 Information Selective 机制,从而使不同输入能影响矩阵 、 和 ,即将它们设计为是输入 的函数:
这里 仍然是静态的是为了让状态自己保持静止,不同输入则可以通过 和 来影响状态的变化。
通过这种机制之后,Mamba 具有了能根据不同输入自主选择将哪些内容保留在状态中、忽略哪些内容的能力。
由于 Information Selective 机制的引入,Mamba 不再是线性时不变 SSM了,这也使得它不能再依赖卷积来实现快速训练了。作者采用的并行扫描和硬件感知内存管理算法来保证 Mamba 的训练效率。
并行关联扫描(Parallel Associative Scan)用于执行前缀和操作(Prefix Sum)。前缀和操作是对数字序列进行累积运算的过程。这个操作是“关联”的,即操作中数字的分组方式不会改变结果。
例如,对于序列[1, 2, 3, 4, 5]
,前缀和操作将生成新的序列[1, 3, 6, 10, 15]
,其中每个元素是原始序列中该位置之前所有元素的总和。
并行关联扫描是指将前缀和操作在并行计算环境中进行加速的技术。在传统的串行计算中,执行前缀和操作需要进行多次迭代,并且每次迭代都需要等待上一次迭代的结果。而在并行关联扫描中,可以将序列分成多个部分,并同时对这些部分进行部分前缀和操作。然后,通过将每个部分的结果进行组合和调整,最终得到整个序列的前缀和结果。
在 Mamba 模型的上下文中,通过定义关联算子,获得用于并行关联扫描操作的元素和关联算子。最终通过并行运算解决问题,从而降低计算时间。
GPU 的一个缺点是其较小但高效的 SRAM 与较大但效率稍低的 DRAM 之间的传输 (IO) 速度有限。在 SRAM 和 DRAM 之间频繁复制会成为计算瓶颈。
为了进一步提升 Mamba 的计算效率,作者利用 Nvidia GPU 中 HBM(High Bandwidth Memory)和 SRAM(Static RAM)速度相关的特性,提出了一种硬件感知算法。
将 Mamba 状态计算按照如下方式安排来进一步提高计算速度:
介绍晚 Mamba 关键技术之后,来整体看下 Decoder 的 Mamba-block 结构,如下图所示:
与 Transformer 架构类似,Mamba 架构也是由堆叠的 Mamba 块组成。
在原本 SSM 不擅长的 Selective Copying(a)和 Induction Heads(b) 任务中,Mamba 架构表现出了超强的性能。
与其他知名的开源模型相比,Mamba 在多个常见的下游 zero-shot 任务上取得了最好的效果:
在计算效率上,Mamba 比传统的 SSM 实现方式的 Training 速度快了40倍。在 Inference 阶段,Mamba 比相当大小的 Transform 的吞吐量提高了5倍:
本文从 State Space Model (SSM) 基础知识开始,介绍了 Mamba 架构的由来,以及它要解决的问题和效果。本文希望大家能意识到 SSM 研究方向的潜力。
虽然说 Mamba 要替代 Transform 还为时过早,但它提供了一种完全不同的结构,并在某些任务上取得了与 Transform 相当的性能。就这一点,就足以让我们保持持续关注了。
[1] Gu A, Goel K, Ré C. Efficiently modeling long sequences with structured state spaces[J]. arXiv preprint arXiv:2111.00396, 2021.
[2] Gu A, Johnson I, Goel K, et al. Combining recurrent, convolutional, and continuous-time models with linear state space layers[J]. Advances in neural information processing systems, 2021, 34: 572-585.
[3] Kola Ayonrinde. Mamba Explained. The Gradient, 2024
[4] Grootendorst. A Visual Guide to Mamba and State Space Models. Exploring Language Models.2024
[5] Loïck Bourdois. Introduction to State Space Models. huggingface. 2024.
[6] James Chen. Mamba No. 5 (A Little Bit Of...). Sparse Notes. 2024.
53AI,企业落地应用大模型首选服务商
产品:大模型应用平台+智能体定制开发+落地咨询服务
承诺:先做场景POC验证,看到效果再签署服务协议。零风险落地应用大模型,已交付160+中大型企业
2024-03-30
2024-04-26
2024-05-10
2024-04-12
2024-05-28
2024-05-14
2024-04-25
2024-07-18
2024-04-26
2024-05-06
2024-12-22
2024-12-21
2024-12-21
2024-12-21
2024-12-21
2024-12-20
2024-12-20
2024-12-19