Zubnet AI學習Wiki › Gradient Checkpointing
Training

Gradient Checkpointing

Activation Checkpointing, Rematerialization
訓練時用運算換記憶體的省記憶體技術。不是存前向傳播的所有中間激活(反向傳播需要的),gradient checkpointing 只在某些「檢查點」層存激活,反向傳播時重算其他的。這把記憶體使用減少 5–10 倍,代價是約 30% 更多運算。

為什麼重要

gradient checkpointing 就是讓在有限 GPU 記憶體上 fine-tune 大模型成為可能的東西。沒有它,一個 7B 模型訓練時光激活就可能需要 80+ GB,超過單 GPU 容量。有了 gradient checkpointing,同樣的模型能在 24GB 消費級 GPU 上 fine-tune。它是訓練中最常用的記憶體優化。

Deep Dive

During the forward pass, each layer's input activations are needed during the backward pass to compute gradients. Normally, all activations are stored in memory. With gradient checkpointing, only certain layers' activations are stored. During the backward pass, when an unstored activation is needed, the forward pass is re-run from the nearest checkpoint to recompute it. This trades ~30% extra compute (recomputing activations) for ~5x memory savings (not storing them all).

Checkpoint Placement

The optimal placement of checkpoints depends on the model architecture. The simplest approach: checkpoint every N layers (e.g., every 3rd Transformer block). More sophisticated: analyze the activation sizes per layer and place checkpoints to minimize total memory while limiting recomputation. Some frameworks (PyTorch's torch.utils.checkpoint) make this as simple as wrapping a layer call in a checkpoint function.

Combining with Other Techniques

Gradient checkpointing composes with other memory optimizations: mixed precision (FP16/BF16 halves activation size), gradient accumulation (smaller batches reduce peak memory), and FSDP/DeepSpeed (shard parameters across GPUs). Together, these can reduce a model's memory footprint by 10–50x compared to naive FP32 training, enabling training of models that are far larger than any single GPU's memory. This stack of optimizations is standard for fine-tuning 7B+ models.

相關概念

← 所有術語
← GQA Gradient Descent →