支持私有化部署
AI知识库

53AI知识库

学习大模型的前沿技术与行业应用场景


一起聊聊Nvidia Hopper新特性之WGMMA

发布日期:2025-04-06 07:31:17 浏览次数: 1553 作者:机智流
推荐语

深入探索Nvidia Hopper GPU新矩阵乘法操作WGMMA,掌握高效GEMM内核设计。

核心内容:
1. 介绍Hopper架构下Tensor Core的WGMMA指令
2. 分析WGMMA指令对GEMM算法的影响
3. 探讨WGMMA在CUTLASS库中的高效实现

杨芳贤
53A创始人/腾讯云(TVP)最具价值专家

本文翻译自外网资料,译者:企鹅?

原文链接见文末

上次为大家带来了Hopper上的新特性之TMA,这次我们来一起看看Hopper上的新矩阵乘法操作WGMMA。

引子

如果一个 CUDA 教程没有关于通用矩阵乘法(GEMM)的章节,那么就是不完整的。可以说,GEMM 是现代 GPU 上最重要的例程,它在神经网络、大型语言模型和许多图形应用程序中构成了大部分计算。尽管 GEMM 无处不在,但它以难以有效实现而闻名。

这个由三部分组成的教程系列旨在让读者全面了解如何使用 CUTLASS 库在 NVIDIA Hopper GPU 上编写高效的 GEMM 内核。

  • [第 1 部分,即本部分] 讨论了 warp 组矩阵乘法累加(WGMMA)指令。这些是针对基于 Hopper 架构的 NVIDIA GPU 的 Tensor Core 的原始指令。
  • [第 2 部分] 将讨论高效 GEMM 内核的整体设计,包括 CUTLASS 内核中使用的高级技术,如 warp 特化和乒乓调度。
  • [第 3 部分] 将讨论持久内核和 Stream-K,这是一种针对 GEMM 的负载平衡策略,可在大量问题几何形状中实现最先进的效率。

本系列的三个部分大致遵循通用矩阵乘(GEMM)内核的整个开发过程,但采用“由内向外”的方式。首先,我们有按分块进行的 GEMM 基本操作,它调用张量核心(Tensor Cores)来最终进行计算。其次,我们有从每个线程束协同线程组(CTA)角度看到的 GEMM 内核设计——由序言、主循环和尾声组成——其中主要挑战是避免内存加载成为快速张量核心的瓶颈。最后,我们在最外层网格级别对 CTA 进行调度,此时负载平衡考虑因素成为首要问题。

我们希望在阅读完本系列后,读者将成为 GEMM 算法的专家,并能够利用该算法中的一些优秀理念来设计和实现他们自己工作中的其他内核。

Asynchronous Warpgroup MMA (WGMMA)

Hopper 引入了异步线程束组级矩阵乘法累加运算(WGMMA)。一个线程束组由四个连续的线程束组成,即 128 个连续的线程,其中第一个线程束的线程束编号是 4 的倍数。wgmma.mma_async指令由线程束组中的所有 128 个线程共同执行。此操作通常采用以下形式之一,其中矩阵C用作累加器:

  • C = A * B + C
  • C = A * B, where the input from accumulator C is disabled.

WGMMA 的一个显著要求是操作数B必须始终存储在共享内存(SMEM)中。相比之下,操作数A可以位于共享内存或寄存器内存(RMEM)中,并且累加器C始终保存在 RMEM 中。

这篇博客文章的结构如下。首先,我们讨论在 CUTLASS 中调用wgmma.mma_async指令的要点。这涉及构建相关的TiledMMA,以及创建和划分 SMEM 张量以与 WGMMA 兼容。其次,我们讨论确保 WGMMA 正确性所需的同步机制。最后,我们更详细地讨论 WGMMA 中使用的布局,包括来自 SMEM 的操作数的核心矩阵和矩阵描述符的概念。

在整个过程中,为了简洁起见,我们将wgmma.mma_async缩写为wgmma

CUTLASS kernel 中的WGMMA

在本教程中,我们的主要目标是解释用于调用 Hopper Tensor Cores 进行基于分块的 GEMM 的wgmma原语,以及如何将其作为cute::gemm调用的一部分进行调用。为了做好准备,考虑一个标准的 GEMM 内核,它接收维度为MxNxK的输入矩阵A和B,并计算C = A*B。为了并行化计算,内核固定静态分块大小bM、bN和bK,并启动一个由⌈M/bM⌉x⌈N/bN⌉多个线程块(CTAs),每个 CTA 计算输出矩阵的一个bMxbN瓦片rC。这将在被写回到全局C矩阵之前保存在 CTA 的本地内存(RMEM)中。

根据CTA,我们就有了内核的主循环。通过多次迭代,我们循环内部维度,并将A和B的bMxbK和bNxbK块依次从全局加载到共享内存中,作为sA和sB;请注意,在CUTLASS中,我们将sB的形状固定为数学上的转置。(事实上,反映了常见的做法,我们将A和B的块加载到循环ME M缓冲区中,其中的级数由编译时整数给出,例如2或3。然后sA和sB的形状元组的最后一种模式由该阶段计数给出。)cute::gemm调用然后计算sA和sB的(分阶段切片)的乘积,并将值连续累加到rC中。主循环完成后,最后将rC写入全局内存。

现在,我们希望解释以下cute::gemm调用及其参数。

template <class TiledMMA, ... >
__global__ device_gemm(TiledMMA tiled_mma, ...) {
  // PROLOGUE
  // ...
  // Define A/B partitioning and C accumulators
  ThrMMA thr_mma = tiled_mma.get_thread_slice(threadIdx.x);
  Tensor tCsA = thr_mma.partition_A(sA);  // (MMA,MMA_M,MMA_K,PIPE)
  Tensor tCsB = thr_mma.partition_B(sB);  // (MMA,MMA_N,MMA_K,PIPE)
  Tensor tCgC = thr_mma.partition_C(gC);  // (MMA,MMA_M,MMA_N)

  // Allocate accumulators and clear them
  Tensor tCrC = thr_mma.make_fragment_C(tCgC);  // (MMA,MMA_M,MMA_N)
  clear(tCrC);

  // Allocate "fragments"
  Tensor tCrA = thr_mma.make_fragment_A(tCsA);  // (MMA,MMA_M,MMA_K,PIPE)
  Tensor tCrB = thr_mma.make_fragment_B(tCsB);  // (MMA,MMA_N,MMA_K,PIPE)
   
  // PIPELINED MAIN LOOP
while (k_tile_count > -K_PIPE_MAX) {
    // ...
    // MMAs to cover 1 K_TILE
    cute::warpgroup_arrive();
    // (V,M,K) x (V,N,K) => (V,M,N)
    cute::gemm(tiled_mma, tCrA(_,_,_,read_pipe), tCrB(_,_,_,read_pipe), tCrC);
    cute::warpgroup_commit_batch();
    // Wait for all MMAs in a K_TILE to complete
    cute::warpgroup_wait<0>();
    // ...
  }

  // EPILOGUE
  // ...
}

在 CUTLASS 的 MMA(矩阵乘法累加)范式中,“MMA范式”里的cute::gemm方法旨在通过统一的接口展示特定架构的 MMA 指令。(实际上,如果你查看SM80 教程的 GEMM 内核,你会看到那里的cute::gemm调用在语法上与上述相同。)然而,cute::gemm调用所涉及的参数定义包含许多 WGMMA 特定的方面:

  • TiledMMA对象tiled_mma的定义封装了cute::gemm分派到特定wgmmaPTX指令所需的信息。
  • 必须定义SMEM张量sA和sB的布局以与wgmma兼容。
  • 片段tCrAtCrBtCrC使用TiledMMA对象构建为数据的线程级分区,因此具有程序员应该注意的WGMMA特定布局。
  • 片段tCrA(如果从SMEM获取操作数A)和tCrB不是寄存器支持的张量,其值是从SMEM复制的,而是在SMEM之上构建的矩阵描述符。

最后,当然,在cute::gemm调用周围有线程组同步原语。我们将依次解释所有这些概念。

WGMMA中的TiledMMA对象

以下内容中,假设数据类型为 FP16,且A和B是MN,所以在 BLAS 表示法中,我们正在计算一个 NT gemm。我们使用cute::make_tiled_mma方法在主机上构造TiledMMA对象,如下所示:

TiledMMA tiled_mma = cute::make_tiled_mma(
  SM90_64x64x16_F16F16F16_SS<GMMA::Major::MN,GMMA::Major::MN>{});

虽然cute::make_tiled_mma也有一些可选参数,但让我们专注于当前的这个参数——矩阵乘法累加原子(MMA Atom)。这是一个结构体,它封装了一个底层的 PTX 调用,在这种情况下是:

wgmma.mma_async.sync.aligned.m64n64k16.f16.f16.f16

CUTLASS 符号系统使得人们可以立即读出包装后的 PTX 指令与 MMA 原子之间的关系。首先,SM90 是 Hopper 架构的另一个名称。然后,SM90 MMA 原子被标记为SM90_MxNxK_XYZ_SS或SM90_MxNxK_XYZ_RS,其中有两个模板参数可以是GMMA::Major::MNGMMA::Major::K。它们的含义如下:

  • X和Y是操作数的数据类型。
  • Z是累加器的数据类型。
  • MxNxK是wgmma指令计算的大小-“wgmma atom”。并非MxNxK的所有值都是可进行mma操作的的。这是允许形状的列表:M始终为64,N是从8到256的8倍数,对于16位操作数数据类型,K为16(更一般地说,K固定为32字节)。
  • 后缀RS或SS指示操作数A是来自寄存器(R)还是来自共享内存(S)。操作数B始终来自共享内存,因此S。
  • 这两个模板参数指示操作数A和B在MN模式或K模式下是内存连续的。例如,在BLAS表示法中,两个操作数都是K-Major将对应于一个TN gemm(参见此表)。请注意,对于16位操作数数据类型,可以灵活地将内存布局设置为MN-Major或K-Major。但是,对于非16位操作数数据类型,布局必须始终为K-Major。

这就是你需要了解的 MMA Atom 的语法!现在,我们已经强调过 WGMMA 是一个全线程组指令。在代码中,你可以使用其大小来检索参与由 TiledMMA 对象定义的 MMA 操作的线程数量。例如,以下主机代码。

dim3 dimBlock(cute::size(tiled_mma));

规定内核中的每个 CTA 以 1 个包含 128 个线程的线程束组启动。假设我们想要2 个线程束组来执行 WGMMA,由不同的线程束组独立计算输出块的一半(并且每个线程束组发出各自的wgmma指令)。为此,我们可以将一个非平凡的布局(AtomLayoutMNK)作为第二个参数传递给make_tiled_mma方法。例如,以下代码。

TiledMMA tiled_mma = make_tiled_mma(
 SM90_64x64x16_F16F16F16_SS{},
 Layout<Shape<_2,_1,_1>>{});

定义了一个 WGMMA 操作,其中 warp 组 1 和 2 分别计算输出瓦片的上半部分和下半部分,沿M模式划分(现在假设bM是 128 的倍数)。此外,size(tiled_mma)将等于 256。

一般来说,make_tiled_mma的两个可选布局参数——AtomLayoutMNK和PermutationMNK——对于任何 MMA 原子都同样适用。

共享内存的布局约束了WGMMA

接下来,我们解释在给定 MMA 原子选择的情况下,共享内存中操作数矩阵的瓦片大小和布局的约束。首先,对于任何 MMA 指令,MMA 原子的MxNxK需要能够整除操作数和累加器Tile的大小。在我们的例子中,这意味着bM应该是 64 的倍数,bN是 64 的倍数,bK是 16 的倍数。

其次,WGMMA 对sA和sB的共享内存布局(包括形状和跨度)施加了一个额外的约束,并且这个约束会随着所选的交错模式而变化。特别是,(分阶段切片的)sA的布局通常不是简单的(bM,bK):(1,bM)或(bM,bK):(bK,1),sB也是如此。

为了深入理解这些要求,我们需要“核心矩阵”的概念,我们将在下面介绍。然而,实际上,我们总是可以使用 CUTLASS 提供的某些预定义布局原子,然后使用cute::tile_to_shape方法构建保证与兼容的布局。在我们的示例中,我们在主机上准备瓦片大小和如下(其中<T = cutlass::half_t>,这是 CUTLASS 对 FP16 的名称):

auto bM = Int<128>{};
auto bN = Int<128>{};
auto bK = Int< 64>{};  
auto bP = Int<  3>{};  // Pipeline
 
auto sA = cute::tile_to_shape(
    GMMA::Layout_MN_SW128_Atom<T>{},
    cute::make_shape(bM, bK, bP)
);
auto sB = cute::tile_to_shape(
    GMMA::Layout_MN_SW128_Atom<T>{},
    cute::make_shape(bN, bK, bP)
);

在这里,MN表示布局原子适用于MN主操作数,而SW128是 128 字节的交错模式。输出sA或sB会显示。

Sw&lt;3,4,3> o smem_ptr[16b](unset) o ((_64,_2),(_8,_8),_3):((_1,_512),(_64,_1024),_8192)

这个布局是从哪里来的?cute::tile_to_shape采用一个布局(同名的tile)并复制它以平铺在更大的形状上(类似于numpy.tile)。抛开swizzle函数Sw<3,4,3>,我们知道布局原子由(64,8):(1,64)给出,并以列主要方式平铺在形状(128, 64, 3)上,因此对于MxK形状,512的较小外步幅位于M模式,而1024的较大外步幅位于K模式。(8192的最大步幅在于阶段计数P模式,这是有道理的,因为sA或sB的不同阶段切片不应该在内存中混合。)

请注意,64乘以sizeof(half_t)等于128字节,这是swizzle模式的名称。这是设计:由于核心矩阵的工作方式,我们总是在连续方向上安排布局原子的长度以等于swizzle字节数-对于无swizzle,可以是16,或者32、64或128之一。

相对的,如果我们考虑:

auto sA = cute::tile_to_shape(
  GMMA::Layout_K_SW128_Atom<T>{},
  cute::make_shape(bM,bK,bP)
);
auto sB = cute::tile_to_shape(
  GMMA::Layout_K_SW128_Atom<T>{},
  cute::make_shape(bN,bK,bP)
);

打印sA会得到我们预期的结果。

Sw&lt;3,4,3> o smem_ptr[16b](unset) o (_128,_64,_3):(_64,_1,_8192)

由于我们改为在(8,64):(64,1)上平铺(8,64):(64,1)。(请注意,布局((_8,_16),(_64,_1),_3):((_64,_512),(_1,_0),_8192)合并为(_128,_64,_3):(_64,_1,_8192))。

一般来说,我们可以在8种布局原子的可能性中进行选择,它们对应于MN或K为主以及四种混洗模式之一:

  • 无交错:无交错。隐含 16 字节边界。
  • 32 字节交错:交错 2 个连续的 16 字节段。
  • 64 字节交错:交错 4 个连续的 16 字节段。
  • 128 字节交错:交错 8 个连续的 16 字节段。
GMMA::Layout_MN_INTER_Atom<T>
GMMA::Layout_MN_SW32_Atom<T>
GMMA::Layout_MN_SW64_Atom<T>
GMMA::Layout_MN_SW128_Atom<T>
 
GMMA::Layout_K_INTER_Atom<T>
GMMA::Layout_K_SW32_Atom<T>
GMMA::Layout_K_SW64_Atom<T>
GMMA::Layout_K_SW128_Atom<T>

然后,必须将这些布局原子传入tile_to_shape,其中sA和sB的共享内存(SMEM)形状由make_shape(bM,bK,bP)make_shape(bN,bK,bP)给出,形状的模式按照该顺序给出,使得布局原子的分块大小能够整除较大的 SMEM 形状的分块大小。这最终是由混洗模式的选择对 SMEM 形状造成的约束,并且与由矩阵乘法累加(MMA)原子形状施加的另一个约束分开。

WGMMA 片段和描述符

我们创建了TiledMMA对象,并在主机上相应地准备了共享内存(SMEM)布局。现在,在设备上,我们可以使用TiledMMA对象tiled_mma来构建适当的分区张量,以便传递到cute::gemm调用中。首先,我们通过在tiled_mma上调用带有线程索引的get_thread_slice方法来创建一个名为thr_mma的ThrMMA对象。在我们的例子中,线程索引从0到127。

接着,参考上面的内核代码片段,打印张量tCsA和tCsB对于任何线程索引,显示如下:

tCsA: Sw&lt;3,4,3>_smem_ptr[16b](0x7f8800000400) o
    ((_64,(_8,_2)),_2,_4,_3):((_1,(_64,_1024)),_512,_2048,_8192)
tCsB: Sw&lt;3,4,3>_smem_ptr[16b](0x7f880000c400) o
    ((_64,(_8,_2)),_2,_4,_3):((_1,(_64,_1024)),_512,_2048,_8192)

根据注释,tCsA的形状应被视为(MMA,MMA_M,MMA_K,PIPE):

  • MMA是MMA Atom的NxK形状。
  • MMA_M和MMA_K是它在sA的M和K模式上平铺的范围(因此MMA_M=bM/64=2和MMA_K=bK/16=4)。
  • PIPE是stage数。
tCrA: GMMA::DescriptorIterator o (_1,_2,_4,_3):(_0,_64,_256,_1024)
tCrB: GMMA::DescriptorIterator o (_1,_2,_4,_3):(_0,_64,_256,_1024)

在内部,CUTLASS 构造一个“矩阵描述符”,这是一个保存在寄存器中的 64 位值,以一种适合wgmma指令使用的方式描述共享内存(SMEM)。对于程序员来说,最重要的是要记住,共享内存的值不会被复制到寄存器内存(RMEM)中;相反,访问 tCrA 和 tCrB 的值实际上是访问这些 64 位描述符。此外,这些张量作为“迭代器”意味着在任何时候,对于给定的wgmma指令,只有一个 64 位描述符保存在寄存器中(例如,与全部 24 个不同)。

与操作数相比,累加器张量以更标准的方式定义。打印线程 0 的tCgC和tCrC显示:

tCgC: gmem_ptr[16b](0x7f877a780000) o ((_2,_2,_8),_2,_2):((512,_8,4096),_64,32768)
tCrC: ptr[16b](0x7feee1fffbe0) o ((_2,_2,_8),_2,_2):((_1,_2,_4),_32,_64)

tCgC是输出 GMEM 张量的一部分,我们希望在尾声中将累加器的值复制到该部分,而tCrC是为了在主循环中计算这些值时保存这些值而创建的基于寄存器的张量。这些张量的(MMA,MMA_M,MMA_N)形状可以如下解释:在 MMA 原子的MxN=64x64输出块中,128 个线程中的每个线程都持有32=2*2*8个值,并且MMA_M=MMA_N=2与tCsA和tCsB相同。

每个线程以一种需要将 32 分解为(2,2,8)形状的方式持有原子的 32 个值,以便能够为tCgC的布局定义相应的步长。具体的分区模式可以从取自 PTX 文档的这张图片中读出:

这说明了重复的 Z 模式,其中一个线程的 32 个值被保存。例如,线程 0 保存着(0,0)、(0,1)、(8,0)、(8,1)处的值,并每向右 8 列重复一次。

Gemm call

让我们回到上面内核代码片段的第 25 行:

// (V,M,K) x (V,N,K) => (V,M,N)
cute::gemm(tiled_mma, tCrA(_,_,_,read_pipe), tCrB(_,_,_,read_pipe), tCrC);

cute::gemm方法的各种重载首先用于循环遍历外部模式MMA_M/N和MMA_K。一旦选择了这些坐标,我们就使用矩阵乘法累加器原子瓦片形状进行计算。换句话说,我们首先将其简化为针对cute::gemm的调度形状(V)x(V)=>(V)的重载。

然后,代码调用矩阵乘法累加器原子的fma操作(确切地说,在矩阵乘法累加器解包(mma_unpack))。这里包含了一些PTX汇编代码:

CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      uint32_t& d00, uint32_t& d01, uint32_t& d02, uint32_t& d03,
      uint32_t& d04, uint32_t& d05, uint32_t& d06, uint32_t& d07,
      uint32_t& d08, uint32_t& d09, uint32_t& d10, uint32_t& d11,
      uint32_t& d12, uint32_t& d13, uint32_t& d14, uint32_t& d15,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %18, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n64k16.f16.f16.f16 "
      "{%0,  %1,  %2,  %3,  %4,  %5,  %6,  %7,  "
      " %8,  %9,  %10, %11, %12, %13, %14, %15},"
      " %16,"
      " %17,"
      " p,   %19, %20, %21, %22;\n"
    "}\n"
      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15)
      : "l"(desc_a),
        "l"(desc_b),
        "r"(int32_t(scale_D)),
        "n"(int32_t(scaleA)),
        "n"(int32_t(scaleB)),
        "n"(int32_t(tnspA)),
        "n"(int32_t(tnspB)));
#else
    CUTE_INVALID_CONTROL_PATH(
        "Attempting to use SM90_64x64x16_F16F16F16_SS "
        "without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }

这种语法对应的 PTX 文档在此处。与上述对张量 tCrA、tCrB 和 tCrC 的描述一致,请注意,对于操作数我们有 uint64 类型的变量 desc_a 和 desc_b,同时对于累加器有 16 个 uint32 类型的变量。scale_D 的值为 0 或 1,它控制累加器是否进行零初始化。

此外,变量 scaleA、scaleB、tnspA 和 tnspB 是在 fma 方法外部通过模板参数在编译时确定的。scaleA 和 scaleB 的值为 1 或 -1,用于对操作数取负;而 tnspA 和 tnspB 表示是否对操作数进行转置,当值为 0 时对应 GMMA::Major::K,值为 1 时对应 GMMA::Major::MN

WGMMA的同步

接下来还需解释围绕 cute::gemm 调用的同步原语:

cute::warpgroup_arrive();
cute::gemm(tiled_mma, tCrA(_,_,_,read_pipe), tCrB(_,_,_,read_pipe), tCrC);
cute::warpgroup_commit_batch();
cute::warpgroup_wait<0>();

为什么这些额外的指令完全有必要呢?这一切都与 wgmma 作为一条异步指令的特性有关。在霍普(Hopper)架构的背景下,“异步” 意味着 wgmma 可以与其他操作并发运行,因此对于有依赖关系的步骤而言,就需要一种同步机制。这种机制在 PTX 内存一致性模型中有详细阐述。代码中如果同步不当,可能会导致以下情况:(a)出现难以察觉的竞态条件,进而引发棘手的错误;(b)编译器会将 wgmma 指令按顺序执行,这可能会导致性能大幅下降;或者(c)出现未定义行为。

cute 方法封装了以下 PTX 指令:

  • cute::warpgroup_arrive() — wgmma.fence.sync.aligned;
  • cute::warpgroup_commit_batch() — wgmma.commit_group.sync.aligned;
  • cute::warpgroup_wait<N>() — wgmma.wait_group.sync.aligned N;

(注意,我们一直用 wgmma 作为 wgmma.mma_async 的简写,但仅在本小节我们会明确区分二者。)让我们把这些指令的用法与从 PTX 文档中逐字引用的以下基于 WGMMA 的通用矩阵乘法(GEMM)描述联系起来:

  • 将矩阵 A、B 和 D 加载到寄存器或共享内存中。
  • wgmma.fence 操作,用于表明整个线程束组的寄存器 / 共享内存已完成写入。
  • fence.proxy.async 操作,使通用代理操作对异步代理可见。
  • 使用 wgmma.mma_async 操作对输入矩阵发起异步矩阵乘法和累加操作。wgmma.mma_async 操作在异步代理中执行。
  • 创建一个 wgmma 组,并使用 wgmma.commit_group 操作将之前所有未完成的 wgmma.mma_async 操作提交到该组。
  • 使用 wgmma.wait_group 等待所需的 wgmma 组完成操作。
  • 一旦 wgmma 组完成操作,所有的 wgmma.mma_async 操作就都已执行完毕。

我们按顺序解释这些要点。首先,wgmma.fence 指令可确保 wgmma.mma_async 仅在对某些寄存器内存(RMEM)地址的所有先前访问完成后,才访问这些地址。如果没有 wgmma.fence,其行为是未定义的。该规则的一个例外是,霍普(Hopper)架构允许同时执行多条 wgmma.mma_async 指令。只要这些 wgmma.mma_async 指令的累加器形状相同,它们就可以共享同一个累加器张量,即写入相同的寄存器内存地址。在这种情况下,就不需要同步(fence)操作。例如,在 cute::gemm 调用中对 MMA_K 进行循环时,我们无需插入 wgmma.fence

与张量内存访问(TMA)操作一样,wgmma.mma_async 是在异步代理中执行的。因此,如果在通用代理中执行的操作会影响 wgmma.mma_async 读取的共享内存(SMEM),我们就需要发出 fence.proxy.async 指令。例如,如果我们通过普通的 ld.global/st.shared 操作将矩阵 A 和 B 复制到共享内存中,就会出现这种情况。由于我们使用了 TMA 加载,在示例中就不需要 fence.proxy.async,实际上,它也未出现在 WGMMA 教程代码或 CUTLASS 霍普架构通用矩阵乘法(GEMM)内核的主循环中。(要验证这一点,请注意 fence.proxy.async 是由 cutlass::arch::fence_view_async_shared() 封装的。)

wgmma.commit_group 指令会为每个线程束组创建一个新的 wgmma 组,并将执行线程束组发起但尚未提交到任何 wgmma 组的所有先前的 wgmma.mma_async 指令批量处理到这个新的 wgmma 组中。在我们的示例中,cute::warpgroup_commit_batch() 会将 MMA_M * MMA_N * MMA_K 条 wgmma.mma_async 指令批量处理到一个 wgmma 组中。

最后,带有参数 N 的 wgmma.wait_group 指令会使执行线程等待,直到最近的 wgmma 组中未完成的数量不超过 N 个,并且执行线程提交的所有先前的 wgmma 组都已完成。在我们的示例中,我们将 N 设为 0,这样线程束组只需等待整个 wgmma 组完成,然后再继续执行后续指令。

在线程束组有机会执行独立计算的情况下,参数 N 的灵活性就派上用场了。例如,在 FlashAttention - 3 的设计中采用的 GEMM - softmax 重叠策略就会用到这一点。

WGMMA核心操作

最后这部分将进一步讨论加载到共享内存(SMEM)中的矩阵 A 和矩阵 B 的分块布局要求,假设 wgmma 的两个操作数均来源于共享内存。为简化讨论,首先假设 A 是按行优先存储,B 是按列优先存储(即两者都是按 K 优先存储)。还要记得,wgmma 指令的分块形状 MxNxK 是受限制的,其中 M 为 64,数据类型大小乘以 K 为 32 字节,N 是 8 的倍数,取值范围从 8 到 256。为避免与 A/B 或 sA/sB 混淆,我们将 WGMMA 的原子分块记为 wA 和 wB。

矩阵 wA 和 wB 被划分为许多较小的矩阵,称为核心矩阵。每个核心矩阵都有一个跨步方向和一个连续方向,其在跨步方向上的长度为 8,在连续方向上的长度为 16 字节。矩阵 wA 由 8x2 的核心矩阵组成,矩阵 wB 由 2x(N/8) 的核心矩阵组成。我们通过核心矩阵来展示 wA 和 wB 的分块情况如下(图片取自 PTX 文档):

如上文所述,处于同步流模式(SS 模式)的 wgmma 需要矩阵描述符,即 wA 的描述符(desc-a)和 wB 的描述符(desc-b)作为输入。这种描述符对五个参数进行了编码:

  • 起始地址:操作数在共享内存(SMEM)中的起始基地址。
  • 首维字节偏移量(LBO,leading dimension byte offset):在 K 维度上,两个相邻核心矩阵之间的字节距离。
  • 跨步步长字节偏移量(SBO,stride dimension byte offset):在 M 或 N 维度上,两个相邻核心矩阵之间的字节距离。
  • 混洗模式:无、32 字节、64 字节或 128 字节。
  • 矩阵基偏移量:当共享内存地址未与混洗模式下重复模式的字节边界对齐时,该偏移量用于解决共享内存的对齐问题。

首维字节偏移量(LBO)和跨步步长字节偏移量(SBO)已在上图中标示出来了。

CUTLASS 中的 make_gmma_desc 方法会根据作为输入提供的共享内存(SMEM)张量的布局来构建描述符(作为 GmmaDescriptor 的一个实例)。只要输入张量的布局是使用八种规范的通用矩阵乘法(GMMA)布局原子之一以及 tile_to_shape 来创建的(如之前在 “WGMMA 的共享内存布局约束” 中详细介绍的那样),make_gmma_desc 就会准确计算出首维字节偏移量(LBO)和跨步步长字节偏移量(SBO),确定混洗模式,并构建出描述符。例如,GmmaDescriptor 描述了在按 K 优先存储的情况下以下可接受的 WGMMA 布局(其中 T*sizeof(dtype)=16):

No swizzle       : Swizzle&lt;0,4,3> o smem_ptr o ((8,m),(T,2)):((1T,SBO),(1,LBO))
32-byte swizzle  : Swizzle&lt;1,4,3> o smem_ptr o ((8,m),(T,2)):((2T,SBO),(1, T ))
64-byte swizzle  : Swizzle&lt;2,4,3> o smem_ptr o ((8,m),(T,2)):((4T,SBO),(1, T ))
128-byte swizzle : Swizzle&lt;3,4,3> o smem_ptr o ((8,m),(T,2)):((8T,SBO),(1, T ))

最值得注意的是,对于 64 字节和 128 字节的混洗模式,其步长使得给定的可接受的 WGMMA 布局并非紧凑布局。相反,在 K 方向上会有 2 组或 4 组 WGMMA 原子操作数分块并排堆叠,从而在核心矩阵的 M 模式下产生 4T 和 8T 的步长。换句话说,在混洗时,在内存中会对在 K 模式下逻辑上相邻的 2 个、4 个或 8 个核心矩阵进行交错排列,并且对于 64 字节和 128 字节的混洗模式,这些核心矩阵将属于不同的 WGMMA 原子。

为了内容的完整性,我们也给出在按 MN 优先存储情况下可接受的 WGMMA 布局:

No swizzle       : Swizzle&lt;0,4,3> o smem_ptr o ((T,1,m),(8,k)):((1,T,SBO),(1T,LBO))
32-byte swizzle  : Swizzle&lt;1,4,3> o smem_ptr o ((T,2,m),(8,k)):((1,T,LBO),(2T,SBO))
64-byte swizzle  : Swizzle&lt;2,4,3> o smem_ptr o ((T,4,m),(8,k)):((1,T,LBO),(4T,SBO))
128-byte swizzle : Swizzle&lt;3,4,3> o smem_ptr o ((T,8,m),(8,k)):((1,T,LBO),(8T,SBO))

总结

在通用矩阵乘法(GEMM)系列的[第一部分]中,我们探讨了在基于(Hopper)架构的 GEMM 中,将线程束组矩阵乘法与累加(WGMMA)作为基本操作时涉及的核心概念。

WGMMA 需要一个由 128 个线程组成的线程束组来协同执行矩阵乘法,并且只能对矩阵的特定片段进行操作。我们深入探讨了其中涉及的特殊形状和布局,着重介绍了如何使用规范的通用矩阵乘法(GMMA)布局 => 分块转换形状(tile_to_shape)模式来构建确保能被 WGMMA 接受的操作数布局。

为了确保其使用行为明确,WGMMA 还需要特定的同步机制。为此,我们解释了 wgmma.fencefence.proxy.asyncwgmma.commit_group 和 wgmma.wait_group 与 wgmma.mma_async 之间的关联及用途。

最后,我们详细解释了 WGMMA 核心矩阵的内部工作原理,以及 CUTLASS 如何为那些源自共享内存(SMEM)的操作数构建矩阵描述符。

总体而言,这篇博客文章应能让程序员在Hopper架构上编写使用 WGMMA 的 CUTLASS 内核。在[第二部分]中,我们将扩展讨论范围,引入张量内存访问(TMA)技术,以及如何在霍普架构的 GEMM 内核中同时使用 TMA 和 WGMMA,从而实现数据复制和计算的重叠操作。


原文链接:https://research.colfax-intl.com/cutlass-tutorial-wgmma-hopper/


-- 完 --


机智流推荐阅读

1. QCon 全球软件开发大会 | 与全球 140+ 顶尖工程师共同解构 AI 时代的技术浪潮

2. 解读 Fin-R1 | 从数据集构建和训练方法聊聊如何用70亿参数革新复杂金融推理

3. 聊聊大模型推理系统之 LServe:MIT和NVIDIA联合提出革新长序列LLM服务效率的秘密武器

4. AutoGLM 沉思:最流畅的浏览器体验,最“聪明”的多轮思考


欢迎在「机智流」公众号后台回复「cc」,加入机智流大模型交流群;回复「HF」即可加入我们不定期举办的HuggingFace Daily Paper高赞论文分享活动群,也会分享大厂AI论文快讯。与我们一起探索 AI 与人类潜能的未来,一起共赴 AI 浪潮!

53AI,企业落地大模型首选服务商

产品:场景落地咨询+大模型应用平台+行业解决方案

承诺:免费场景POC验证,效果验证后签署服务协议。零风险落地应用大模型,已交付160+中大型企业

联系我们

售前咨询
186 6662 7370
预约演示
185 8882 0121

微信扫码

添加专属顾问

回到顶部

加载中...

扫码咨询