融合线性交叉熵

融合线性交叉熵#

LoongForge 为模型输出层提供了一种显存优化方案。通过将 hidden @ weight.T 线性投影与交叉熵损失融合为单一操作,并结合分块计算,显著降低了词表投影阶段的峰值显存占用。

在标准训练中,输出层生成形状为 (num_tokens, vocab_size) 的完整 Logits 张量,该张量在反向传播期间再次被保留,导致显存开销翻倍。对于典型配置(num_tokens=16384, vocab_size=129280),仅 Logits 相关的显存就可能达到 ~40 GB。本优化通过两步递进方案解决此问题:

  • 步骤 1(算子融合):将线性投影和交叉熵融合为单一 autograd Function,反向传播由框架控制,仅保存轻量级统计信息(每个 Token 的最大值和指数和),无需在前向和反向传播之间存储完整 Logits

  • 步骤 2(分块计算):沿词表维度将权重切分为小块(默认 vocab_per_split=3072),使用在线 Softmax 算法逐块计算并立即丢弃,因此完整 Logits 张量永远不会被实例化

LoongForge 提供两种实现路径,框架根据 GPU 架构自动选择

使用方法#

在训练启动脚本中添加以下参数:

--cross-entropy-loss-fusion \
--cross-entropy-fusion-impl linear

1. 通用实现#

LoongForge 使用混合精度、缓冲区复用、原地操作和 Autograd 等策略实现了纯 PyTorch 通用版本,使此优化能在任何 CUDA GPU 上运行,同时相比原生 Torch 实现也具有显著的性能优势。

核心设计:预分配宽度为 vocab_per_split 的小缓冲区,将每次矩阵乘法直接写入该缓冲区(通过 out= 参数),结果在同一循环迭代中立即被在线 Softmax 消费,并在下一轮被覆盖——完整 Logits 永远不会在 Python 层面累积:

matmul_buf = torch.empty((num_tokens, vocab_per_split), ...)  # allocate only one chunk size

for split_idx in range(num_splits):
    torch.matmul(hidden, weight[v_start:v_end].t(), out=matmul_buf)  # write into reused buffer
    logits_chunk.sub_(new_max.unsqueeze(1)).exp_()                    # in-place, immediately consumed
    accumulate.mul_(torch.exp(maximum - new_max)).add_(chunk_sum)     # update statistics
    maximum = new_max
    # next matmul directly overwrites the buffer, complete logits never existed

反向传播同样逐块重计算,使用前向传播保存的 maximumaccumulate(形状均为 (num_tokens,))恢复每个块的 Softmax 概率,无需保存完整 Logits。

特性#

  • 适用于任何 CUDA GPU,无硬件限制

  • 在 A800 上比原生实现快 22~26%

  • 在线 Softmax 保证与原生实现数值完全一致(Loss/梯度差异 < 1e-5)

  • 支持 DP / TP / SP 并行策略,支持 FP8 训练


2. 调优参数#

通过环境变量控制分块大小,平衡显存占用和性能:

# 默认值 3072,显存与性能的最优平衡(推荐)
export LCE_GENERIC_FWD_VOCAB_SPLIT_SIZE=3072
export LCE_GENERIC_BWD_VOCAB_SPLIT_SIZE=3072

# 显存充足时增大分块大小以提高 GPU 利用率
export LCE_GENERIC_FWD_VOCAB_SPLIT_SIZE=8192

# 显存极度受限时减小分块大小
export LCE_GENERIC_FWD_VOCAB_SPLIT_SIZE=512