Pendant la passe forward, les activations d'entrée de chaque couche sont nécessaires pendant la passe backward pour calculer les gradients. Normalement, toutes les activations sont stockées en mémoire. Avec le gradient checkpointing, seules les activations de certaines couches sont stockées. Pendant la passe backward, quand une activation non stockée est nécessaire, la passe forward est relancée depuis le checkpoint le plus proche pour la recalculer. Cela échange ~30% de calcul supplémentaire (recalcul des activations) contre ~5x d'économies de mémoire (ne pas toutes les stocker).
Le placement optimal des checkpoints dépend de l'architecture du modèle. L'approche la plus simple : checkpoint toutes les N couches (ex : chaque 3ème bloc Transformer). Plus sophistiqué : analyser les tailles d'activations par couche et placer les checkpoints pour minimiser la mémoire totale tout en limitant le recalcul. Certains frameworks (le torch.utils.checkpoint de PyTorch) rendent cela aussi simple que d'envelopper un appel de couche dans une fonction checkpoint.
Le gradient checkpointing se compose avec d'autres optimisations mémoire : la précision mixte (FP16/BF16 divise par deux la taille des activations), l'accumulation de gradients (des lots plus petits réduisent le pic de mémoire), et FSDP/DeepSpeed (répartition des paramètres entre les GPU). Ensemble, celles-ci peuvent réduire l'empreinte mémoire d'un modèle de 10–50x par rapport à un entraînement naïf en FP32, permettant d'entraîner des modèles bien plus grands que la mémoire d'un seul GPU. Cette pile d'optimisations est standard pour le fine-tuning de modèles 7B+.