Zubnet AIApprendreWiki › Gradient Checkpointing
Entraînement

Gradient Checkpointing

Aussi appelé : Activation checkpointing, rematérialisation
Une technique d'économie de mémoire qui échange du calcul contre de la mémoire pendant l'entraînement. Au lieu de stocker toutes les activations intermédiaires de la passe forward (nécessaires pour la rétropropagation), le gradient checkpointing ne stocke les activations qu'à certaines couches "checkpoint" et recalcule les autres pendant la passe backward. Cela réduit l'utilisation mémoire jusqu'à 5–10x au prix d'environ 30% de calcul en plus.

Pourquoi c'est important

Le gradient checkpointing est ce qui rend possible le fine-tuning de grands modèles sur une mémoire GPU limitée. Sans lui, un modèle 7B pourrait nécessiter 80+ Go juste pour les activations pendant l'entraînement, dépassant la capacité d'un seul GPU. Avec le gradient checkpointing, le même modèle peut être fine-tuné sur un GPU grand public de 24 Go. C'est l'optimisation mémoire la plus couramment utilisée pour l'entraînement.

En profondeur

Pendant la passe forward, les activations d'entrée de chaque couche sont nécessaires pendant la passe backward pour calculer les gradients. Normalement, toutes les activations sont stockées en mémoire. Avec le gradient checkpointing, seules les activations de certaines couches sont stockées. Pendant la passe backward, quand une activation non stockée est nécessaire, la passe forward est relancée depuis le checkpoint le plus proche pour la recalculer. Cela échange ~30% de calcul supplémentaire (recalcul des activations) contre ~5x d'économies de mémoire (ne pas toutes les stocker).

Placement des checkpoints

Le placement optimal des checkpoints dépend de l'architecture du modèle. L'approche la plus simple : checkpoint toutes les N couches (ex : chaque 3ème bloc Transformer). Plus sophistiqué : analyser les tailles d'activations par couche et placer les checkpoints pour minimiser la mémoire totale tout en limitant le recalcul. Certains frameworks (le torch.utils.checkpoint de PyTorch) rendent cela aussi simple que d'envelopper un appel de couche dans une fonction checkpoint.

Combinaison avec d'autres techniques

Le gradient checkpointing se compose avec d'autres optimisations mémoire : la précision mixte (FP16/BF16 divise par deux la taille des activations), l'accumulation de gradients (des lots plus petits réduisent le pic de mémoire), et FSDP/DeepSpeed (répartition des paramètres entre les GPU). Ensemble, celles-ci peuvent réduire l'empreinte mémoire d'un modèle de 10–50x par rapport à un entraînement naïf en FP32, permettant d'entraîner des modèles bien plus grands que la mémoire d'un seul GPU. Cette pile d'optimisations est standard pour le fine-tuning de modèles 7B+.

Concepts connexes

← Tous les termes
← GQA Grand modèle de langage →