Forward pass के दौरान, प्रत्येक layer के input activations backward pass के दौरान gradients की गणना करने के लिए आवश्यक हैं। सामान्य रूप से, सभी activations मेमोरी में संग्रहीत होते हैं। Gradient checkpointing के साथ, केवल कुछ लेयर्स के activations संग्रहीत होते हैं। Backward pass के दौरान, जब एक असंग्रहीत activation की आवश्यकता होती है, तो forward pass निकटतम checkpoint से इसे पुन: गणना करने के लिए फिर से चलाया जाता है। यह ~30% अतिरिक्त compute (activations की पुन: गणना) को ~5x मेमोरी बचत (उन सभी को संग्रहीत न करने) के बदले trade करता है।
Checkpoints का इष्टतम प्लेसमेंट मॉडल आर्किटेक्चर पर निर्भर करता है। सबसे सरल दृष्टिकोण: हर N लेयर्स पर checkpoint करें (जैसे, हर 3rd Transformer block)। अधिक परिष्कृत: प्रति layer activation sizes का विश्लेषण करें और re-computation को सीमित करते हुए कुल मेमोरी को कम करने के लिए checkpoints रखें। कुछ frameworks (PyTorch का torch.utils.checkpoint) इसे एक checkpoint function में layer call को wrap करने जितना सरल बनाते हैं।
Gradient checkpointing अन्य मेमोरी ऑप्टिमाइज़ेशन के साथ compose करता है: mixed precision (FP16/BF16 activation size को आधा करता है), gradient accumulation (छोटे batches peak मेमोरी कम करते हैं), और FSDP/DeepSpeed (GPUs में parameters shard करें)। साथ में, ये naive FP32 प्रशिक्षण की तुलना में एक मॉडल के मेमोरी footprint को 10–50x तक कम कर सकते हैं, जो किसी भी single GPU की मेमोरी से कहीं बड़े मॉडलों का प्रशिक्षण सक्षम करते हैं। ऑप्टिमाइज़ेशन का यह stack 7B+ मॉडलों को fine-tune करने के लिए मानक है।