Zubnet AI學習Wiki › 梯度檢查點
訓練

梯度檢查點

別名:激活檢查點、重新具體化

一種以計算換記憶體的訓練節省記憶體技術。梯度檢查點不是儲存前向傳遞中的所有中間激活(反向傳播所需),而是只在某些「檢查點」層儲存激活,在反向傳遞期間重新計算其他的。這可以將記憶體使用量減少多達 5–10 倍,代價是大約 30% 的額外計算。

為什麼重要

梯度檢查點使得在有限的 GPU 記憶體上微調大型模型成為可能。沒有它,一個 7B 模型在訓練期間可能僅激活就需要 80+ GB,超過了單一 GPU 的容量。有了梯度檢查點,同一模型可以在 24GB 的消費級 GPU 上進行微調。它是訓練中最常用的記憶體最佳化方法。

深度解析

在前向傳遞期間,每層的輸入激活在反向傳遞中計算梯度時需要用到。通常,所有激活都儲存在記憶體中。使用梯度檢查點,只有某些層的激活被儲存。在反向傳遞中,當需要未儲存的激活時,從最近的檢查點重新執行前向傳遞來重新計算它。這以約 30% 的額外計算(重新計算激活)換取約 5 倍的記憶體節省(不需要全部儲存)。

檢查點放置

檢查點的最佳放置取決於模型架構。最簡單的方法:每 N 層設一個檢查點(例如每第 3 個 Transformer 區塊)。更精密的方法:分析每層的激活大小,放置檢查點以最小化總記憶體,同時限制重新計算。某些框架(PyTorch 的 torch.utils.checkpoint)使這像在檢查點函數中包裝層呼叫一樣簡單。

與其他技術結合

梯度檢查點可以與其他記憶體最佳化組合:混合精度(FP16/BF16 將激活大小減半)、梯度累積(較小的批次減少峰值記憶體)、以及 FSDP/DeepSpeed(跨 GPU 分片參數)。結合在一起,這些可以將模型的記憶體佔用減少 10–50 倍,相比簡單的 FP32 訓練,使得訓練遠超任何單一 GPU 記憶體的模型成為可能。這套最佳化組合是微調 7B+ 模型的標準配置。

相關概念

← 所有術語
← 月之暗面 標註 →