MTraining: Efficient Distributed Training for Ultra-Long Contexts via Dynamic Sparse Attention
MTraining: Efficient Distributed Training for Ultra-Long Contexts via Dynamic Sparse Attention
Wenxuan Li, Chengruidong Zhang, Huiqiang Jiang, Yucheng Li, Yuqing Yang, Lili Qiu
A1 主要贡献
本文针对超长上下文(Ultra-Long Contexts)大语言模型(LLMs)训练中计算成本过高的问题,提出了一种名为 MTraining 的高效分布式训练框架。尽管动态稀疏注意力(Dynamic Sparse Attention, DSA)在推理阶段能有效降低成本,但在分布式训练(特别是涉及 Context Parallelism)场景下,由于 Worker 级(Worker-level) 和 Step 级(Step-level) 的负载不平衡以及通信瓶颈,直接应用 DSA 面临巨大挑战。
本文的主要贡献包括:
1. 动态稀疏训练模式(Dynamic Sparse Training Pattern):基于 RoPE 注意力的理论分析和观察,识别出训练过程中存在 "Vertical-Slash"(垂直-斜线)的稀疏模式,并设计了在线近似预算机制来动态适应这种稀疏性。
2. 平衡稀疏环状注意力(Balanced Sparse Ring Attention):提出了一种基于条带(Stripe-based)的布局设计,有效地解决了分布式环境下的 Worker 级和 Step 级负载不平衡问题。
3. 分层稀疏环状注意力(Hierarchical Sparse Ring Attention):针对异构带宽环境(节点内 vs 节点间),设计了分层通信策略,掩盖了跨节点的通信开销。
实验表明,MTraining 在 32 张 A100 GPU 上成功将 Qwen2.5-3B 模型的上下文长度从 32K 扩展至 512K。在 RULER、PG-19、InfiniteBench 和 Needle In A Haystack 等基准测试中,MTraining 在保持甚至超越基线模型精度的同时,实现了高达 6 倍 的训练吞吐量提升。
图 1:Striped 和 Zigzag Ring Attention 在 4 个 CP Workers (GPUs) 上的工作负载分布。
A3 背景知识/关键Observation/设计原则
3.1 长上下文训练是动态稀疏的
注意力的动态稀疏性在预训练 LLMs 中已有充分记录,而在训练过程中,这种现象变化更为剧烈。
- 观察结果:如 Fig 2b 所示,注意力稀疏度在不同的训练步骤和输入样本之间波动显著。不同的模型检查点(Checkpoints)即使对于相同的输入也会产生不同的稀疏模式,反映了训练的时间动态性。反之,单个检查点对不同输入也会产生多样化的稀疏区域。
- 结论:这些观察强调了在训练期间进行动态稀疏适应的必要性。
图 2:(a) 训练阶段的延迟分解。(b) 不同样本和训练步骤中 128K 上下文的 top-k (k=1024) 注意力召回率。(c-d) 训练期间注意力权重 (c) 及其梯度 (d) 的可视化。结果基于使用 4×8 A100 集群训练的 Qwen2.5-3B。
3.2 注意力训练稀疏性呈现特定模式
基于注意力计算公式,作者推导了注意力权重 ($S = QK^\top / \sqrt{d_k}, A = \text{softmax}(S)$) 及其梯度($Q, K, V$)的关系。
- 梯度依赖性:通过将 $\frac{\partial L}{\partial S}$ 代入注意力的梯度表达式,可以观察到反向传播中的所有矩阵运算(GEMMs)都依赖于注意力权重 $A$。因此,反向传播中的动态稀疏性可以视为前向阶段稀疏性的叠加。
- Vertical-Slash 模式:如 Fig 2c 和 Fig 2d 所示,梯度 $\frac{\partial L}{\partial S}$ 展现出与前向传播紧密镜像的稀疏模式。值得注意的是,反向梯度显示出结构化的稀疏性,在整个训练过程中始终遵循 Vertical-Slash(垂直-斜线) 的局部性模式。作者将此模式的出现归因于相对位置编码(RoPE)的使用(详见附录 A)。
3.3 分布式动态稀疏注意力是不平衡的
分布式动态稀疏注意力引入了单节点设置中不存在的新挑战,最显著的是 Worker 级 和 Step 级 的不平衡。
- Worker 级不平衡:如 Fig 7 所示,动态稀疏性导致不同 Workers 之间的 FLOPs 分布不均。处理较快 Workers 必须在同步屏障处空闲等待,导致不平衡。例如,使用 xAttention 在 95% 稀疏度和 32 路 Context Parallelism 下,不平衡度达到 3.17,将实际加速比降低到理论最大值的三分之一。
图 7:使用 XAttention (Xu et al., 2025) 时,不同 CP Workers 之间的计算不平衡(FLOPs)。不平衡度 = 最大值/平均值。
- Step 级不平衡:指单个 Worker 在不同的 Ring Attention 步骤(Steps)中计算负载的波动。这种波动由变化的稀疏模式和样本复杂性驱动。
- 通信气泡:如 Fig 3 所示,这种变化导致随时间推移的工作负载不均匀。当高稀疏性导致计算量减少到低于通信延迟时,计算和通信难以重叠,从而产生降低性能的“气泡(Bubbles)”。
图 3:Step 级不平衡导致气泡的示意图,此时计算和通信无法重叠。
A2 方法细节
MTraining 旨在加速超长上下文 LLM 的分布式训练,由三个核心组件构成:适应训练期高动态稀疏性的动态稀疏训练模式、解决 Worker/Step 级不平衡的平衡稀疏环状注意力,以及利用异构带宽的分层稀疏环状注意力。
图 4:分布式场景下的 MTraining 概览。
4.1 平衡稀疏环状注意力 (Balanced Sparse Ring Attention)
在全注意力(Full Attention)和因果掩码(Causal Mask)下,Ring Attention 的 ZigZag 和 Striped 实现都能达到负载平衡。但在动态稀疏注意力设置中,它们不同的激活模式导致了严重的不平衡。
- 现有问题分析:如 Fig 5a 和 Fig 8 所示,ZigZag 沿着反核心对角线(Anti-diagonal)跨 Workers 分配计算,并随 Steps 沿对角线移动;而 Striped 则相反,沿对角线分配并沿反核心对角线移动。由于数据依赖的动态稀疏性,这导致了显著的负载不平衡。
主要组件设计:
-
Striped Sparse Ring Attention(条带化稀疏环状注意力):
- 依据:如 §3.2 和 §B.1 所述,RoPE 注意力在训练中主要表现为 Vertical-Slash 稀疏模式。由于 GPU 的块状操作,Slash(斜线)分量在计算中占主导地位。
- 策略:为了平衡跨 Workers 的负载,系统将 Workers 沿着对角线方向对齐,提出了一种 Striped 动态稀疏环状注意力方案。
- 效果:如 Fig 5a 所示,这种设计将 Slash 线均匀分布在各个 Workers 上,允许每个 Worker 在每个步骤中处理连续的 Slash 区域。
-
Block-level Striped Sparse Ring Attention(块级条带稀疏环状注意力):
- 粒度选择:考虑到 Slash 操作的块级计算特性及其空间局部性,引入了块级条带化设计。
- 具体实现:采用 64-token 的条带粒度。
- 优势:这种粒度保持了连贯性,避免了 Token 级条带化带来的碎片化,并维持了 Kernel 的稀疏性和效率。这种对齐还减少了索引开销并提高了运行时性能。
-
Step-level Balanced Ring Attention(Step 级平衡环状注意力):
- 原理:块级条带化设计同时也缓解了 Step 级的不平衡。
- 机制:在超长上下文设置中,Workers 在每个步骤处理细粒度的条带。例如,对于 128 个 Workers 和 512K 序列,每个 Worker 顺序处理 64 个块条带。
- 效果:这种重复的、细粒度的划分稳定了跨步骤的计算,确保了更一致的工作负载分布。
图 5:4 个 CP Workers 下 Striped Ring Attention (a) 和 Hierarchical Striped Ring Attention (b) 的 Step 级计算调度。
4.2 分层平衡稀疏环状注意力 (Hierarchical Balanced Sparse Ring Attention)
Ring Attention 通常通过并发执行矩阵乘法(matmul)和通信 Kernel 来重叠计算与通信。然而,在动态稀疏性下,单 Worker 计算量的减少放大了通信开销,使其成为主要瓶颈。特别是在具有异构通信链路(如 25 GB/s IB HDR vs 300 GB/s NVLink)的分布式训练中,节点间通信往往成为瓶颈。
主要组件设计:
-
Inner- and Outer-Ring Hierarchical Ring Attention(内外环分层环状注意力):
- 结构:将全局环状通信分解为两个层级:内环(Inner Ring) 和 外环(Outer Ring)。内环中,KV 块在每个计算节点内的 $G_{node}$ 个 GPU 间循环;外环则处理 $N_{node}$ 个节点间的聚合 KV 缓冲区交换。
- 调度流程:每个外环步骤包含以下阶段:
- Post Outer P2P:启动非阻塞 P2P 通信,将本地节点的当前 KV 块发送到下一个节点,并发布匹配的接收请求。
- Inner-Ring Attention:在节点间传输进行的同时,GPU 进入长度为 $G_{node}$ 的循环,对节点内的本地 KV 切片执行稀疏环状注意力计算。
- Synchronize:在外环步骤结束时,同步计算和通信,然后进入下一次迭代。
-
Hierarchical Balanced Sparse Ring Attention(分层平衡稀疏环状注意力):
- 挑战:与全注意力不同,在稀疏设置中应用分层环状注意力会改变 KV 块在 Workers 间的传播顺序,可能影响注意力计算模式。
- 解决方案:如 Fig 5b 所示,即使采用了两级 KV 传输(内环和外环),计算在跨步骤时仍然保持对角线对齐。
- 优势:这种设计保留了 Vertical-Slash 模式 并维持了负载平衡。通过将 MTraining 集成此设计,节点间的 KV 传输与内环计算完全重叠,有效地缓解了节点间数据移动带来的通信开销。
附录 B.1 动态稀疏训练模式 (Dynamic Sparse Training Pattern)
受训练期间 Vertical-Slash 模式的观察和理论验证启发(见 §3.2 和 Appendix A),本文将动态稀疏注意力扩展到训练阶段。作者提出了面向训练的动态稀疏模式,包含以下关键组件:
-
Online Budget Approximation(在线预算近似):
- 目的:适应训练步骤和上下文中稀疏模式的动态变化,并消除离线搜索的开销。
- 方法:在观察窗口内跟踪注意力权重统计信息,并在线估计回忆起目标比例注意力质量所需的最小垂直线和斜线的数量。
-
Kernel-Aware Approximation Granularity(Kernel 感知近似粒度):
-
对齐:由于垂直和斜线模式在 Kernel 中以不同粒度运行,近似分辨率与其匹配:
- 垂直线:在 Token 级 进行估计。
- 斜线:在 64x64 块级 进行池化。
-
效果:这种对齐确保了预算估计与实际 Kernel 执行之间的保真度。
-
Algorithm 1: Dynamic Sparse Training Head
该算法描述了动态稀疏训练头的核心逻辑:
- 利用
last_q近似注意力Ab。 - 在线近似垂直预算
kv和 Top-K 索引iv(基于 Token 级)。 - 在线近似斜线预算
ks和 Top-K 索引is(基于 64x64 块级池化)。 - 构建稀疏注意力索引
ivs。 - 执行动态稀疏 Flash-Attention 计算。
A4 实验环境
- 数据集:ProLong 数据集,最大序列长度 512K tokens,包含 1B tokens,训练 1 epoch。所有样本的平均稀疏率为 0.95。
- 模型架构:Qwen2.5-3B,使用 Yarn-extrapolated RoPE(扩展因子 32)。
-
硬件配置:32 张 Nvidia A100 40GB GPU(4个节点,每个节点8张卡)。
- 连接:节点内 NVLink,节点间 InfiniBand HDR。
-
软件配置:
- Context Parallelism (CP):32。
- 并行策略:NNScaler 自动搜索。
- 优化:Zero-2 (with offloading), Gradient Accumulation (64 steps), Gradient Checkpointing。
- 精度:bfloat16 (权重/梯度/激活), float32 (优化器状态)。
- 核心实现:基于 FlashAttention, BlockSparse Attention, PIT 的自定义 CUDA Kernel。
A4 实验结果
1. 长上下文扩展训练 (Long-context Extension Training)
- 实验内容:将 Qwen2.5-3B 从 32K 扩展训练至 512K。
-
结果与分析:
-
Training Loss (Fig 6a):
- Full Attention 在早期 Loss 下降较快,但 MTraining 在后期紧随其后,显示出强大的收敛性。
- MoBA 初始收敛快于 MTraining,但性能随时间退化,最终 Loss 更高。这归因于其粗粒度稀疏索引与细粒度注意力模式的不匹配。
-
Throughput (Fig 6b):
- MTraining 在 512K 上下文长度下实现了高达 6 倍 的端到端训练加速。
- 相比 "Ours w/ ZigZag" 和 "Ours w/o Hierarchical",速度分别提升了 2.1 倍和 1.3 倍。
- 随着 Worker 数量增加,MTraining 实现了近乎线性的吞吐量扩展,而基线方法在分布式设置中性能显著下降。
-
图 6:Qwen2.5-3B 在 ProLong 数据集上进行 512K 上下文窗口持续预训练期间,不同方法的训练 Loss 和吞吐量比较。
2. 长上下文下游任务 (Long-context Downstream Tasks)
- 实验内容:在 RULER, Needle In A Haystack (NIAH), PG-19, InfiniteBench 上评估模型。
- 结果与分析:
- RULER (Table 1):MTraining 在各种上下文长度上始终优于基线。相比 Dense Training,MTraining 实现了 3% 的整体提升(使用 MInference 推理时提升 13.4%)。
- Needle In A Haystack (Fig 9, 10):MTraining 实现了近乎完美的检索性能,优于基线检查点,且计算成本大幅降低。
- Language Modeling (PG-19) (Fig 11):MTraining 在不同上下文长度下保持了与 Dense 基线相当的困惑度(Perplexity)。
- InfiniteBench (Table 2):MTraining 在编码和摘要能力上优于 Dense 基线,并在问答任务上保持了有竞争力的性能。
3. 效率分析 (Efficiency Analysis)
- 负载平衡 (Fig 13, Table 3):
- MTraining 显著降低了动态稀疏注意力中的 Worker 级和 Step 级不平衡。
- 最大计算时间与平均计算时间的比率分别下降了 2.4 倍和 2.3 倍。
- Balanced Sparse Ring Attention 将 Worker 级不平衡降低了 2.1 倍,Step 级降低了 2.2 倍。
- Hierarchical Sparse Ring Attention 进一步将 Worker 级不平衡降低了 1.2 倍,Step 级降低了 1.03 倍。
图 13:在 32 个 GPU 上使用不同方法处理 512K Tokens 时的注意力计算时间分布:(a) 固定 Ring Attention 步骤内跨 CP Workers 的分布,(b) 固定 Worker 跨 Ring Attention 步骤的分布。
A5 结论
本文提出的 MTraining 框架通过解决 Worker 级和 Step 级的负载不平衡问题,成功实现了动态稀疏注意力在分布式设置下的大规模扩展。MTraining 包含三个关键组件:动态稀疏训练模式、平衡稀疏环状注意力和分层稀疏环状注意力。实验证明,MTraining 能够将 Qwen2.5-3B 高效扩展至 512K 上下文窗口,在 32 张 A100 GPU 上实现了高达 6 倍的吞吐量提升,同时在多个长上下文基准测试中保持或提升了模型精度。
A6 附录
A. 理论证明 (Proof of Theory)
A.1 注意力梯度
作者将 Vertical-Slash 模式的出现归因于 RoPE 的使用。定义 $z_{n,m}$ 为位置 $n, m$ 处 RoPE 变换后的查询和键向量的点积。
定理 A.1:应用 RoPE 后,注意力权重的期望仅依赖于相对位置 $n-m$。即 $E[z_{n,m}] = \sum_{i=0}^{d-1} \phi(i){n-m} A_i + \sum B_i$。}^{d-1} \psi(i)_{n-m
基于定理 A.1,得出两个关键见解:
- RoPE 注意力矩阵呈现 Vertical-Slash 覆盖模式。"Slash" 结构源于对相对位置 $n-m$ 的依赖,而 "Vertical" 分量源于 Query/Key 分布中的异常值。
- RoPE 注意力矩阵倾向于形成带状稀疏激活模式。由于系数是位置无关的,激活倾向于在特定的相对位置周围局部聚集。
A.2 定理 3.1 的详细推导
通过定义三角基函数并对 Key 向量建模为随机变量(包含均值部分和波动部分),推导了点积 $z_{n,m}$ 的期望。如 Equation 14 所示,点积期望是 $(n-m)$ 的多个正弦函数的叠加,从理论上支持了 Vertical-Slash 模式的必然性。
Algorithm 2: Balanced Sparse Ring Attention fuse w/ Hierarchical Sparse Ring Attention
该算法提供了融合分层通信的平衡稀疏环状注意力的伪代码:
- 输入包括 World size、Rank、数据 Q/K/V 以及垂直/斜线索引。
- 外环循环:处理跨节点的 P2P 通信(P2Pouter)。发送本地 KV 到下一个外环节点,接收前一个节点的 KV。
- 内环循环:在等待外环通信的同时,处理节点内的 P2P 通信(P2Pinner)和计算。
- 计算核心:调用 block_bar_sparse_attention_forward 执行前向计算,并合并输出。
- 通过这种嵌套循环结构,实现了计算与跨节点通信的重叠。
A7 补充细节
基线实现细节
1. MoBA:将 KV 序列划分为固定大小的块,对每个 Query 使用 MoE 风格的 Gate 选择 Top-K 相关块(始终包含 Query 自身所在的块)。实验中块大小设为 4096,TopK 为 12,512K 上下文下的稀疏率为 0.9。代码被适配以运行 Zigzag Ring Attention。
2. XAttention:通过沿反核心对角线每隔一定步长求和来对块进行评分,仅保留高分块。实验设置块大小为 128,步长 16,阈值 0.9。
ZigZag 调度可视化
Fig 8 补充展示了 ZigZag Ring Attention 的 Step 级计算调度,用于与 Striped Ring Attention 进行对比,直观显示了其在稀疏设置下导致的不平衡模式(沿反核心对角线分布)。
图 8:Zigzag Ring Attention 的 Step 级计算调度。
💬 评论讨论
欢迎在这里分享您的想法和见解!