PROGRAMMING TENSOR CORES: NATIVE VOLTA TENSOR CORE GEMM

Andrew Kerr, Timmy Liu, Mostafa Hagog, Julien Demouth, John Tran (NVIDIA, March 28, 2018)

目录

1. 引言:在 CUDA 中编程张量核心 (Tensor Cores)

本报告主要介绍在 CUDA 中对张量核心进行编程,核心内容围绕以下几点:
- mma.sync: CUDA 10.1 中新增的底层指令,用于直接操作 Volta 张量核心。
- 数据路径供给 (Feeding the Data Path): 如何高效地为张量核心提供数据。
- CUTLASS 1.3: 一个利用原生 Volta 张量核心实现通用矩阵乘法 (GEMM) 的库。

Page 2: 报告内容概览
Page 2: 报告内容概览

2. 张量核心 (Tensor Cores)

张量核心是为混合精度矩阵乘法专门设计的硬件单元,可带来高达 8 倍的性能提升。之前,开发者通过 CUDA 9 中的 WMMA API (Warp-Matrix Multiply-Accumulate) 以一种可移植的抽象层来使用张量核心。

本报告的重点是 mma.sync 指令,这是 CUDA 10.1 中引入的一条新指令,允许开发者直接访问 Volta 张量核心,从而在 Volta SM 架构上实现最大效率。这项技术已在 CUTLASS 1.3 版本中得到应用。

性能测试表明,使用原生张量核心指令的 CUTLASS 1.3 在多种矩阵尺寸下,其性能均能与高度优化的 cuBLAS 库相媲美,甚至在某些情况下有所超越。

Page 3: 张量核心性能对比图
Page 3: 张量核心性能对比图

本次演讲聚焦于 Volta 张量核心,并区分了两种编程接口:
- Warp-synchronous Matrix Multiply Accumulate (WMMA API):为张量核心提供的一个可移植的、更高层次的抽象接口。
- mma.sync:一种底层指令,用于直接访问 Volta 张量核心,是本次讨论的核心。

Page 4: 编程接口对比
Page 4: 编程接口对比

3. VOLTA MMA.SYNC 指令

mma.sync 是一条线程束范围 (warp-scoped) 的矩阵乘法指令,作为 CUDA 10.1 的新特性,它直接面向 Volta 架构的张量核心。

3.1 操作与数据类型

该指令执行矩阵乘加运算: D = A * B + C
- 输入矩阵 AB 为半精度浮点数 (half, FP16)。
- 累加/输出矩阵 CD 可以是单精度浮点数 (float, FP32) 或半精度浮点数 (half, FP16)。

3.2 线程束同步 (Warp-synchronous)

mma.sync 指令是线程束同步的,在一个 warp (32个线程) 内协同执行。单条指令可以执行四个独立的 8x8x4 矩阵乘加操作

Page 6: mma.sync 指令示意图
Page 6: mma.sync 指令示意图

3.3 Warp 内部的分工

一个 warp 被划分为四个 Quad-Pairs (QP),每个 Quad-Pair 由8个线程组成。
- QP0: 线程 T0..T3, T16..T19
- QP1: 线程 T4..T7, T20..T23
- QP2: 线程 T8..T11, T24..T27
- QP3: 线程 T12..T15, T28..T31

每个 Quad-Pair 负责执行一次 8x8x4 的矩阵乘法。

Page 7: Warp 内 Quad-Pair 的划分
Page 7: Warp 内 Quad-Pair 的划分

3.4 组合矩阵乘法

通过在 warp 内复制数据,可以组合这些独立的计算单元来执行更大规模的矩阵乘法。例如,通过精心安排数据,一条 mma.sync 指令可以计算一个 16x16x4 的矩阵乘积。这通过让不同的 Quad-Pairs 处理输入矩阵的不同分块来实现,如下图所示。

Page 8: 组合计算16x16x4矩阵乘积
Page 8: 组合计算16x16x4矩阵乘积

3.5 PTX 语法

mma.sync 指令在 PTX (Parallel Thread Execution) 汇编层面的语法如下:
mma.sync.aligned.m8n8k4.alayout.blayout.dtype.f16.f16.ctype d, a, b, c;

  • alayout / blayout: 指定输入矩阵 A 和 B 的内存布局,可以是行主序 (.row) 或列主序 (.col)。
  • dtype / ctype: 指定目标矩阵 D 和累加矩阵 C 的数据类型,可以是 .f16.f32
  • d, a, b, c: 分别是目标、A、B、C 矩阵的寄存器操作数。

注意:FP32 元素在传递给指令时需要被打包到 32位寄存器 (.b32) 中。

Page 9: PTX 语法详解
Page 9: PTX 语法详解

3.6 线程与数据映射

3.6.1 FP16 乘数

输入矩阵 A 和 B 的数据(乘数)被分发到一个 Quad-Pair 中的各个线程。下图展示了当输入矩阵 A 和 B 采用不同内存布局(行主序 vs 列主序)时,矩阵元素是如何映射到 Quad-Pair 内特定线程的寄存器中的。这对于程序员准备输入数据至关重要。

Page 10: FP16乘数在线程间的分布
Page 10: FP16乘数在线程间的分布

3.6.2 F16 累加器

本节介绍在使用mma.sync指令进行矩阵乘累加操作时,当累加器数据类型为16位浮点数(F16)时,累加器(Accumulators)如何在线程间分布。此处以一个线程四元组对(Quad Pair)中的线程为例。

执行的指令为:mma.sync.aligned.m8n8k4...
- ctype = {.f16, .f16}:累加器 C 的数据类型为 F16。
- dtype = {.f16, .f16}:目标 D 的数据类型为 F16。
- 每个线程负责的寄存器:
- d: 4 x .f16x2
- c: 4 x .f16x2

这意味着对于一个 m8n8k4 形状的矩阵乘法,当使用 F16 进行累加时,每个线程需要管理4个.f16x2(即8个F16值)的累加器寄存器。下图展示了这种映射关系。

Page 46: F16累加时的线程-数据映射图示
Page 46: F16累加时的线程-数据映射图示

3.6.3 F32 累加器

本节继续探讨mma.sync指令的线程-数据映射,但将累加器的数据类型更改为32位浮点数(F32),这是一种典型的混合精度计算场景。

执行的指令与前一节类似,但累加器类型有所不同:

  • ctype = {.f32, .f32}:累加器 C 的数据类型为 F32。
  • dtype = {.f32, .f32}:目标 D 的数据类型为 F32。
  • 每个线程负责的寄存器:
    • d: 8 x .f32
    • c: 8 x .f32

如果使用 F32 进行累加(以获得更高精度),每个线程需要管理8个.f32类型的累加器寄存器,会占用更多的寄存器资源。下图展示了F32累加器在线程间的分布模式。

Page 47: F32累加时的线程-数据映射图示
Page 47: F32累加时的线程-数据映射图示

4. 数据路径供给 (Feeding the Data Path)

为了充分发挥张量核心的计算能力,必须高效地为其供给数据。这通常通过共享内存 (shared memory) 来实现。

4.1 高效的数据流与关键要求

一个典型的 GEMM 内核数据流如下:

  1. 全局内存 (Global Memory) -> 共享内存 (Shared Memory): 将数据从全局内存加载到片上共享内存中,形成一个线程块级的数据瓦片 (Thread Block Tile)。
  2. 共享内存 (Shared Memory) -> 寄存器文件 (Register File): 从共享内存中将数据加载到每个线程的寄存器中,形成线程级的瓦片 (Thread Tile)。
  3. 寄存器 (Registers) -> 张量核心 (Tensor Cores): mma.sync 指令从寄存器中获取数据进行计算。
  4. 计算结果 -> 全局内存 (Global Memory): 经过可能的后处理(Epilogue),最终结果写回全局内存。
Page 12: GEMM 内核数据流示意图
Page 12: GEMM 内核数据流示意图

总结而言,高效的数据路径供给需要:
- 尽可能使用 128-bit 的内存访问。
- 实现无冲突的共享内存存储操作。
- 实现无冲突的共享内存加载操作。

Page 14: 数据供给的关键要素
Page 14: 数据供给的关键要素

4.2 MMA.SYNC GEMM 的空间交错布局

在使用 mma.sync 实现 GEMM 时,累加器瓦片在逻辑上可能不是连续的。如下图所示,单条 mma.sync 指令计算一个 16x16 的输出瓦片,但在一个更大的 GEMM 计算中,由不同 warp 或不同指令计算出的这些瓦片可以在内存中以空间交错的方式排列,以优化数据局部性或内存访问模式。

Page 15: 空间交错的累加器瓦片
Page 15: 空间交错的累加器瓦片

mma.sync 指令通过空间交错的方式执行通用矩阵乘法(GEMM)。一个线程块(warp)中的线程被划分为多个组,每组线程协同计算结果矩阵的一个子块。下图展示了如何使用4个 mma.sync 指令来计算一个 32x32x4 的矩阵乘法。图中不同的颜色代表了不同的线程组(quad-pairs)。左侧图显示了线程在整个 32x32 目标矩阵上的分布,呈现出一种棋盘式的交错模式。这种空间交错布局旨在优化数据在线程间的共享和访存效率。

Page 16: 空间交错的 mma.sync GEMM 计算示意图。左图展示了整个32x32矩阵的线程分布,右图分解为4个独立的MMA操作。
Page 16: 空间交错的 mma.sync GEMM 计算示意图。左图展示了整个32x32矩阵的线程分布,右图分解为4个独立的MMA操作。

在执行半精度(F16)矩阵乘法时,每个线程负责加载和处理特定的数据片段。对于矩阵 A(列主序),每个线程加载 64 位数据(4个F16元素)。对于矩阵 B(行主序),也采用类似的加载策略。

Page 17: F16乘数的线程-数据映射示意图。
Page 17: F16乘数的线程-数据映射示意图。

为了充分利用内存带宽,访存操作通常以 128 位为单位进行。在空间交错的线程布局下,一次 128 位的内存读取操作可以同时为多个线程提供数据,保证高效的、合并的内存访问。

Page 18: 空间交错布局下的128位内存访问示意图。
Page 18: 空间交错布局下的128位内存访问示意图。

4.3 通过共享内存重排实现无冲突访存

为了给 Tensor Cores 提供持续的数据流,必须尽可能高效地将数据从共享内存(Shared Memory)移动到寄存器(Registers)。为了避免银行冲突 (bank conflicts),需要精心设计内存访问模式。

Page 19: 数据从全局内存到Tensor Cores的流动路径。
Page 19: 数据从全局内存到Tensor Cores的流动路径。

4.3.1 内存布局:从规范到重排

  • 全局内存的规范布局: 在全局内存中,数据通常以规范的(Canonical)布局存储。一个线程束(warp)中的线程会以条带状(striped)的方式访问连续的数据块。
    Page 20: 全局内存中的规范数据布局,线程以条带方式访问。

  • 共享内存的重排布局: 为了实现无冲突的访存,数据在从全局内存加载到共享内存时会进行重排(Permuted)。这种重排布局经过精心设计,使得后续从共享内存到寄存器的加载操作可以实现 128 位的、无冲突的并行访问。
    Page 21: 共享内存中的重排数据布局。

4.3.2 存储到重排共享内存分块

数据转换的核心是从全局内存的列主序布局加载,然后以重排后的布局存储到共享内存中。这个加载和存储的过程分多个阶段完成,每个阶段由 warp 中的一部分线程执行。

Page 22: 从全局内存(上)到重排共享内存(下)的数据加载与存储。
Page 22: 从全局内存(上)到重排共享内存(下)的数据加载与存储。
  • 阶段 1 (Phase 1): 线程 T0-T7 从全局内存加载数据,并交叉存入共享内存的指定位置。
    Page 23: 重排共享内存分块过程 - 阶段1 (T0-T7)。

  • 阶段 2 (Phase 2): 线程 T8-T15 执行类似的操作,填充共享内存的下一批位置。
    Page 24: 重排共享内存分块过程 - 阶段2 (T8-T15)。

  • 阶段 3 (Phase 3): 线程 T16-T23 继续填充。
    Page 25: 重排共享内存分块过程 - 阶段3 (T16-T23)。

  • 阶段 4 (Phase 4): 线程 T24-T31 完成最后的填充。
    Page 26: 重排共享内存分块过程 - 阶段4 (T24-T31)。

通过这四个阶段,整个 warp 协同地将一个数据块从全局内存以合并的方式读出,并以一种优化的、重排过的布局写入共享内存。

4.3.3 指针偏移量计算

为了实现上述重排,需要精确计算每个线程在全局内存和共享内存中的读写地址。
- 全局内存 (列主序): 地址计算相对直接。
- 共享内存 (重排后): 地址计算更为复杂,它引入了异或(XOR)操作来实现重排,从而避免 bank conflict。smem_row 的计算中 (c ^ (r % 4)) 是实现重排的核心。

Page 27: 全局内存和重排共享内存的指针偏移量计算伪代码。
Page 27: 全局内存和重排共享内存的指针偏移量计算伪代码。

4.3.4 从重排共享内存无冲突加载

当数据以重排布局存储在共享内存后,就可以从中高效地加载数据到寄存器,以供给 mma.sync 指令。重排布局确保了当一组线程需要为其 mma 操作加载数据时,它们访问的共享内存地址会分布在不同的 bank 中,从而避免冲突。

Page 28: 数据流再次聚焦于从共享内存到寄存器文件。
Page 28: 数据流再次聚焦于从共享内存到寄存器文件。

该加载过程同样分阶段进行,每个阶段由 warp 中的不同线程组执行,以确保无冲突访问。

  • 阶段 1:
    Page 29: 阶段1,线程 T0-T3 为第一个四线程组(QP0)无冲突地加载数据。
    Page 30: 阶段1,线程 T0-T7 为QP0和QP1无冲突地加载数据。

  • 阶段 2, 3, 4: 后续阶段由 warp 中其余线程(T8-T15, T16-T23, T24-T31)执行类似的无冲突加载操作。
    Page 31: 无冲突共享内存加载 - 阶段2
    Page 32: 无冲突共享内存加载 - 阶段3
    Page 33: 无冲突共享内存加载 - 阶段4

5. CUTLASS 1.3

CUTLASS是一个用于深度学习的CUDA C++模板库,旨在为GEMM(通用矩阵乘法)等计算提供高度优化的可复用组件。
- 分块结构 (Blocked structure):最大化数据复用。
- 软件流水线 (Software pipelined):隐藏访存延迟。
- 无冲突共享内存访问 (Conflict-free Shared Memory access):最大化数据吞吐量。

完整的计算流程包括:从全局内存加载数据到共享内存,再到寄存器,通过Tensor Cores计算,最后通过Epilogue函数处理结果并写回全局内存。

Page 36: CUTLASS GEMM计算流程概览
Page 36: CUTLASS GEMM计算流程概览

5.1 针对Volta Tensor Cores的可复用组件

CUTLASS 1.3 提供了模块化的组件,这些组件可以组合起来,高效地利用Volta架构的Tensor Cores。整个流程被分解为几个关键部分:
1. GlobalLoadStream: 从全局内存加载数据。
2. Warp Matrix Multiply: 在Warp级别执行矩阵乘法。
3. Epilogue: 处理计算结果并写回。

这些部分又由更细粒度的组件构成,如 GlobalLoadIterator, SharedStoreIterator, WarpMatrixMultiply, SharedLoadTransformer 等。

Page 37: CUTLASS 1.3 的可复用组件
Page 37: CUTLASS 1.3 的可复用组件

5.2 CUTLASS 实现流程

5.2.1 存储到共享内存 (Storing to Shared Memory)

使用CUTLASS的Tile迭代器,将数据从全局内存加载并转换到共享内存中。
- 转换: 将全局内存中的列主序(Column-major)矩阵布局,转换为共享内存中的置换布局(permuted shared memory layout),以便后续的无冲突访问。

Page 38: 从全局内存存储到共享内存的流程与代码示例
Page 38: 从全局内存存储到共享内存的流程与代码示例

5.2.2 从共享内存加载 (Loading from Shared Memory)

同样使用Tile迭代器,将数据从共享内存加载到寄存器文件。
- 转换: 将共享内存中的置换布局,转换为寄存器文件中符合mma.sync指令要求的线程数据映射(mma.sync thread-data mapping)。

Page 39: 从共享内存加载到寄存器的流程与代码示例
Page 39: 从共享内存加载到寄存器的流程与代码示例

5.2.3 执行 MMA.SYNC (Executing MMA.SYNC)

调用mma.sync指令,在Warp级别执行矩阵乘法累加操作。
- 转换: 将寄存器文件中符合mma.sync要求的数据映射,送入Tensor Cores执行计算。

Page 40: 执行mma.sync指令的流程与代码示例
Page 40: 执行mma.sync指令的流程与代码示例

5.3 性能评估:与WMMA的加速比

在Volta V100 GPU和CUDA 10.1环境下,针对Transformer模型的不同问题规模,比较了CUTLASS 1.3中mma.sync实现相对于WMMA API的性能加速比。结果显示,CUTLASS 1.3在各种场景下均表现出显著的性能提升,加速比范围从约1.1倍到1.7倍以上。

Page 41: CUTLASS 1.3 mma.sync 相对于 WMMA 的加速比
Page 41: CUTLASS 1.3 mma.sync 相对于 WMMA 的加速比

6. 结论

  • Volta Tensor Cores的可编程性:在CUDA 10.1中,Volta Tensor Cores可以通过mma.sync指令直接编程,这补充了原有的WMMA API。

  • CUTLASS 1.3

    • 作为一个为深度学习设计的CUDA C++模板库,它提供了针对Volta Tensor Cores的可复用组件。
    • 实现了从置换布局的共享内存中进行高效存储和加载。
    • 为优化输出的epilogue提供了高效的工作空间。
  • 新内核 (New kernels)

    • 提供了针对Tensor Cores的实数和复数混合精度GEMM。
    • mma.sync GEMM实现了并行化归约(Parallelized reductions)。

项目地址: https://github.com/NVIDIA/cutlass

Page 42: 结论要点
Page 42: 结论要点

7. 参考文献

Page 43: 参考文献列表
Page 43: 参考文献列表