梯度检查点使在有限GPU内存上微调大型模型成为可能。没有它,7B模型在训练期间可能仅激活就需要80+ GB,超过单块GPU的容量。有了梯度检查点,同一模型可以在24GB的消费级GPU上微调。它是训练中最常用的内存优化。
在前向传播过程中,每层的输入激活在反向传播中需要用于计算梯度。通常,所有激活都存储在内存中。使用梯度检查点,只存储某些层的激活。在反向传播过程中,当需要未存储的激活时,从最近的检查点重新运行前向传播来重新计算它。这以约30%的额外计算(重新计算激活)换取约5倍的内存节省(不需要存储所有激活)。
检查点的最佳放置取决于模型架构。最简单的方法:每N层设置检查点(例如每3个Transformer块)。更复杂的方法:分析每层的激活大小,放置检查点以最小化总内存同时限制重计算。一些框架(PyTorch的torch.utils.checkpoint)使这简单到只需将层调用包装在检查点函数中。
梯度检查点与其他内存优化可以组合:混合精度(FP16/BF16将激活大小减半)、梯度累积(更小的批次减少峰值内存)和FSDP/DeepSpeed(跨GPU分片参数)。结合使用,这些可以将模型的内存占用与朴素FP32训练相比减少10–50倍,使训练远大于任何单块GPU内存的模型成为可能。这套优化组合是微调7B+模型的标准配置。