Tree Training: Accelerating Agentic LLMs Training via Shared Prefix Reuse
Tree Training: Accelerating Agentic LLMs Training via Shared Prefix Reuse
Shaojie Wang * † 1 Jinghui Wang * † 1 Yinghan Cui * 1 Xuxing Chen * 1 Chao Wang * 1 Liang Huang 1 Xiaojiang Zhang 1 Junyi Peng 1 Li Wan 1 Haotian Zhang 1 Bin Chen 1
A1 主要贡献
本文针对Agentic LLM(代理式大语言模型)训练中普遍存在的多轮交互和分支路径问题,提出了一种高效的训练框架——Tree Training。主要贡献如下:
- 核心问题识别:作者指出Agentic LLM的训练轨迹(Trajectory)由于并发工具调用、Think-mode(思维模式)、子代理(Sub-agents)等设计,天然形成带有共享前缀的树状结构(Tree-structured),而非简单的线性序列。现有的训练流程通常将这些轨迹线性化并独立处理每个分支,导致前向和后向传播中存在大量冗余计算。
- Gradient Restoration(梯度恢复):这是本文的核心创新点。不同于仅适用于推理的KV Cache,作者提出了一种在后向传播中消除冗余前缀计算的方法。该方法允许每个共享前缀仅计算一次,通过对梯度进行特定的补偿(Scaling),使得最终聚合的梯度在数学上严格等价于对所有分支独立训练的结果,且开销极低。
- Tree Packing(树打包):为了解决单个轨迹树可能超过GPU显存限制的问题,作者重新设计了训练引擎以支持树状数据输入,并提出了一种基于启发式DFS的内存高效分区策略(Tree Packing)。该策略将大树分割为多个子树,在满足显存约束的同时最大化前缀重用率。
- 显著的性能提升:在密集模型(Dense)和混合专家模型(MoE)上的实验表明,该方法在监督微调(SFT)和强化学习(RL)的模型更新阶段,均能实现高达 6.2倍 的端到端训练加速,且不损失模型精度。
A3 背景知识/关键Observation/设计原则
前向与后向传播的不对称性
在自回归(Autoregressive)LLM中,前向传播(Forward Pass)利用因果掩码(Causal Mask),即输出 $O_i$ 仅依赖于当前及之前的Token。这使得推理时的 Prefix Caching(前缀缓存)成为可能:相同的字首(Prefix)产生相同的Key/Value状态,可以被复用。
然而,训练时的后向传播(Backward Pass)呈现“转置”的因果关系:
* 前向传播:
由于因果掩码,注意力矩阵 $P$ 是下三角矩阵。相同的前缀在不同序列中具有相同的 $P$ 和 $V$,因此输出 $O$ 相同。
* 后向传播:
关键Observation:即使两个序列拥有完全相同的前缀,由于后缀不同,反向传播回前缀的梯度 $dV_{prefix}$ 也会不同(如图2所示)。传统的缓存方法若要支持训练,必须存储所有分支的后缀信息,这将导致显存消耗爆炸,在实践中不可行。因此,必须寻找一种数学上等价但无需存储所有后缀状态的方法来聚合梯度。
A2 方法细节
3.2 梯度恢复 (Gradient Restoration)
基线方法 vs. 树训练方法
* 基线方法 (Baseline):通常的做法是将树状轨迹中的每条从叶节点到根节点的路径回溯,构建成独立的序列(List-like tensors)。这种线性化处理导致共享前缀在每个序列中被重复计算。
* 本文方法 (Our Method):采用树状序列化,将共享前缀 $token(i)$ 仅保留一份,并被所有子节点复用。
梯度更新的等价性推导
为了证明树训练的有效性,必须保证两种方法计算出的梯度是等价的。考虑一个根节点带有 $n$ 个子节点的简化2层子树情况。
对于线性变换 $Y = X \times weight$,权重梯度为 $dweight = X^T \times dY$。基线方法将前缀 $P$ 与每个后缀 $S_i$ 拼接,而本文方法仅保留一个 $P$。
为了使梯度贡献等价,必须满足:
这意味着,前缀 $P$ 在我们方法中对 $dweight$ 的贡献,必须等于它在基线方法中所有对应前缀贡献的总和。
算法分析与证明
为了保证参数更新的等价性,必须满足两个条件:
1. 梯度聚合等价:
针对Transformer中的不同操作,作者分别进行了论证:
* 线性操作 (Linear Operation):
线性操作是逐点(pointwise)的,即 $dX = dY \times weight^T$。每个 $dX_i$ 仅依赖于 $dY_i$。因此,只要输入的梯度 $dY$ 满足聚合条件(即前缀的 $dY$ 是所有分支 $dY$ 的总和),线性操作就能保持这种梯度修正的传递性。这意味着我们只需在反向传播开始前引入一个 梯度缩放 (Gradient Scaling) 步骤,即可恢复正确的模型更新。
-
注意力操作 (Attention Operation):
$$\begin{aligned} \begin{aligned} dO_{P}^{ours} = dO_{p_{1}}^{base} + dO_{p_{2}}^{base} + ... + dO_{p_{n}}^{base} \\ dO_{S_{i}}^{ours} = dO_{S_{i}}^{base}, \forall i \in [1, n] \end{aligned} \end{aligned}$$
虽然Attention包含复杂的矩阵乘法,但作者证明梯度修正同样具有传递性。如图3所示,只要满足 $dO^{ours}_P = \sum dO^{base}_{pi}$ 和 $dV^{ours}_P = \sum dV^{base}_{pi}$,即可推导出前缀的 $dV$ 梯度也满足求和关系。
$$dV_{P}^{ours} = dV_{p_{1}}^{base} + dV_{p_{2}}^{base} + ... + dV_{p_{n}}^{base}$$$$dV_{S_{i}}^{ours} = dV_{S_{i}}^{base}, \forall i \in [1, n]$$ -
其他操作 (如RoPE):
$$dX_i = dY_i * \frac{\partial Y_i}{\partial X_i}$$
对于像RoPE这样依赖位置的操作 $Y = Rope(X, m)$,梯度依赖于位置 $m$。因此,必须确保树打包后的前缀使用的 Position ID 与原始基线方法中的一致。这通过解耦位置编码与物理存储位置来实现。
$$\frac{\partial Y_{i}^{ours}}{\partial X_{i}^{ours}}=\frac{\partial Y_{i}^{base}}{\partial X_{i}^{base}}$$
实现细节
基于上述理论,实现包括三个关键组件(如图6所示):
1. 共享前缀注意力掩码 (Shared Prefix Attention Mask):在前向传播中,引入一种修改后的因果掩码。该掩码限制每个Token的注意力范围,确保不同轨迹的Token可以安全地共享前缀表示,而不会发生信息泄露(即不同分支间不可见)。作者基于 Flash Attention V3 【Shah et al., Flashattention-3: Fast and accurate attention with asynchrony and low-precision, 2024, arXiv】 实现了支持节点级共享前缀掩码的高性能GPU内核。
2. 位置嵌入 (Position Embedding):树打包后的数据物理位置发生了变化。为了保持一致性,必须恢复原始(打包前)的Position IDs。例如,某分支的Token在树结构中可能紧跟在前缀之后,其Position ID应接续前缀的ID,而非其在打包内存中的索引。
3. 梯度缩放器 (Gradient Scaler):这是最核心的组件。作者计算每个节点在树中被复用的次数(称为 tree-scale)。在反向传播开始前,将每个节点的梯度乘以对应的 tree-scale 因子。
* 工作流:如图7所示,如果一个前缀节点被5条轨迹复用,其梯度会被乘以5。这在数学上等价于将该前缀独立计算5次并累加梯度。
* 并行性:Tree Training与现有的并行策略(TP/EP/DP/PP)正交,可无缝结合。对于上下文并行(Context Parallelism),只需根据查询分片生成对应的注意力掩码即可。
* MoE负载均衡:对于有辅助损失(Auxiliary Loss)的MoE模型(如Qwen3 MoE),通过在计算Router辅助损失时将前缀Token乘以其共享计数(Gradient Scaler),即可实现与基线的数学等价。
3.3 树打包 (Tree Packing)
问题与策略
在实际训练中,完整的轨迹树可能过大无法放入GPU显存。因此,需要一种打包算法将大计算树分割为一系列满足显存限制 $C$ 的子树,同时最大化前缀共享。
理论上的最优划分需要动态规划结合装箱问题(Bin Packing),属于NP-hard问题,对于大规模树计算成本过高。
启发式DFS算法
作者采用了一种贪心启发式DFS(深度优先搜索)算法,该算法随树的大小线性扩展,能有效逼近最优解。其原则包括:
- 优先分配最深叶节点:因为它们对总轨迹长度贡献最大。
- 分组:将同一子树中深度相似的叶节点组合在一起,提高打包的同质性。
- DFS遍历:按深度优先顺序遍历树,一旦累积长度超过容量 $C$,则启动新的遍历(生成一个新的Packed Sequence)。
效果示例:如图4所示,一个包含4条轨迹、总计83k tokens的树,若限制显存为60k tokens。基线线性化方法会产生164k tokens。而Tree Packing将其分为两个序列,总计仅102k tokens,显著减少了冗余。
A4 实验环境
-
数据集:
- 真实Agentic轨迹:收集自Terminus 【Terminus, https://www.tbench.ai/terminus 】 和 Claude code 【claude-code, https://claude.com/product/claude-code 】 的多轮RL Rollout数据。包含并发工具执行、Retokenization Drift以及Think-mode(思维模式)产生的树状结构。
- 合成数据集:为了控制变量,构建了不同潜在重叠率(POR)的合成数据,POR范围从20%到92%。
-
模型架构:
- Dense模型:Qwen3-32B。
- MoE模型:Qwen3-30B(Mixture-of-Experts)。
-
硬件配置:
- 集群包含64个NVIDIA Hopper GPU。
-
软件配置:
- 基于 Megatron-Core 【Shoeybi et al., 2020】 框架进行分布式训练。
- 对比基线使用标准的 Sequence Packing 【Krell et al., 2021】 策略。
A4 实验结果
1. 性能指标定义
- 潜在重叠率 (POR):理论上的计算重用上限。定义为 $POR = 1 - \frac{N_{tree}}{N_{X_{base}}}$,即树状结构Token总数与基线线性化后Token总数的比率。
- 有效重用率 (ERR):在显存约束下实际消除的冗余比例。定义为 $ERR = 1 - \frac{N_{pack}}{N_{X_{base}}}$。
2. 真实场景下的加速与正确性
* 加速效果:在包含Think-mode的真实Rollout数据上,理论加速上限为6.5倍。Tree Training在32B Dense模型上实现了 6.3倍 加速,在30B MoE模型上实现了 6.2倍 加速,达到了理论上限的95%以上。
* 正确性验证:如图8(下半部分)所示,Tree Training的训练Loss曲线与基线完全重合,相对误差始终保持在极低水平(<1%),证实了方法的数学等价性。
3. 不同POR下的加速表现
* 如图9所示,随着数据集POR(重叠率)的增加,Tree Training带来的加速比单调递增。
* 在理想情况下(全树可放入显存),加速比最高可达 8.7倍。
* 即使在需要切分大树的情况下(图9b),加速效果依然显著。
4. 显存开销
* 如表1所示,引入的额外组件(Attention Masks, Position IDs, Gradient Scalers)带来的显存开销极低。例如对于32B模型,这些组件仅占用不到1MB的显存,相对于高达64GB的基线激活显存需求几乎可以忽略不计。
5. 下游任务性能提升
* 在Terminal Bench 2.0基准测试中,使用开启Think-mode的全树训练(Full Tree Training)相比于仅使用最长轨迹(Single Trajectory)训练,模型得分从 20.9 提升至 28.8,证明了利用树状轨迹中所有分支数据的有效性。
A5 结论
本文提出的 Tree Training 框架通过 Gradient Restoration 和 Tree Packing 技术,成功解决了Agentic LLM训练中因树状轨迹线性化导致的计算冗余问题。该方法在数学上严格保证了梯度更新的正确性,即与独立训练所有分支完全等价。实验证明,该框架在真实世界的Agentic RL和SFT任务中,能够在极低的额外开销下实现显著的训练加速(最高达6.2倍),且适用于Dense和MoE等多种模型架构。这不仅提升了训练效率,也为未来利用更复杂的树状思维链和多路径交互数据进行模型训练铺平了道路。
A6 附录
附录 A: Tree Packing 动态规划 (DP) 解决方案
A.1 单路径打包 (Single-Path Packing)
首先考虑简化情况,即每个训练步仅建立一条共享路径 $[r \to u]$。
* 定义 $L(u)$ 为共享前缀长度,$R(u)$ 为节点 $u$ 下所有叶子节点产生的剩余长度总和。
* 节点 $u$ 可行的条件是其共享前缀加上最大的子树路径长度不超过容量 $C$:
* 节省的长度 $S(u) = (n_u - 1)L(u)$。
* 定义 $DP(u)$ 为覆盖以 $u$ 为根的子树所能实现的最大长度节省。
* 递推公式为:
A.2 多路径打包 (Multi-Path Packing)
单路径策略在容量 $C$ 较大时可能无法实现最大复用(如图10所示,分两步分别复用 $r \to u \to v_1$ 和 $r \to u \to v_5$ 不如一次性复用 $r \to u$ 并分叉更优)。因此扩展到多路径设置:
* 状态定义:每个节点 $u$ 维护一个候选集合 $\mathcal{A}_u = \{(\mathbf{b}, \text{cost}(\mathbf{b}))\}$,其中 $\mathbf{b}$ 是一个向量,表示到达 $u$ 的多个遍历中已使用的容量。
* 操作:
1. Lift (提升):将子节点的状态向上传播,加上边长。
2. Bin Packing (装箱):将子节点提升上来的需求组合进容量为 $C_u$ 的箱子中。
- 这种精确DP方法理论最优但计算成本极高(NP-hard且状态空间指数增长)。因此在实际的大规模树中,作者使用前文提到的启发式DFS算法来逼近。
💬 评论讨论
欢迎在这里分享您的想法和见解!