Zubnet AI学习Wiki › 梯度检查点
训练

梯度检查点

别名:激活检查点、重计算
一种在训练过程中用计算换内存的节省内存技术。梯度检查点不存储前向传播的所有中间激活(反向传播需要),而是只在某些"检查点"层存储激活,并在反向传播过程中重新计算其他激活。这以约30%的额外计算为代价,将内存使用减少高达5–10倍。

为什么重要

梯度检查点使在有限GPU内存上微调大型模型成为可能。没有它,7B模型在训练期间可能仅激活就需要80+ GB,超过单块GPU的容量。有了梯度检查点,同一模型可以在24GB的消费级GPU上微调。它是训练中最常用的内存优化。

深度解析

在前向传播过程中,每层的输入激活在反向传播中需要用于计算梯度。通常,所有激活都存储在内存中。使用梯度检查点,只存储某些层的激活。在反向传播过程中,当需要未存储的激活时,从最近的检查点重新运行前向传播来重新计算它。这以约30%的额外计算(重新计算激活)换取约5倍的内存节省(不需要存储所有激活)。

检查点放置

检查点的最佳放置取决于模型架构。最简单的方法:每N层设置检查点(例如每3个Transformer块)。更复杂的方法:分析每层的激活大小,放置检查点以最小化总内存同时限制重计算。一些框架(PyTorch的torch.utils.checkpoint)使这简单到只需将层调用包装在检查点函数中。

与其他技术的结合

梯度检查点与其他内存优化可以组合:混合精度(FP16/BF16将激活大小减半)、梯度累积(更小的批次减少峰值内存)和FSDP/DeepSpeed(跨GPU分片参数)。结合使用,这些可以将模型的内存占用与朴素FP32训练相比减少10–50倍,使训练远大于任何单块GPU内存的模型成为可能。这套优化组合是微调7B+模型的标准配置。

相关概念

← 所有术语
← 梯度下降 检查点 →