Durante el pase forward, las activaciones de entrada de cada capa son necesarias durante el pase backward para calcular gradientes. Normalmente, todas las activaciones se almacenan en memoria. Con gradient checkpointing, solo se almacenan las activaciones de ciertas capas. Durante el pase backward, cuando se necesita una activación no almacenada, se vuelve a ejecutar el pase forward desde el checkpoint más cercano para recalcularla. Esto intercambia ~30% de cómputo extra (recalcular activaciones) por ~5x de ahorro de memoria (no almacenarlas todas).
La ubicación óptima de los checkpoints depende de la arquitectura del modelo. El enfoque más simple: hacer checkpoint cada N capas (por ejemplo, cada tercer bloque Transformer). Más sofisticado: analizar los tamaños de activación por capa y colocar checkpoints para minimizar la memoria total mientras se limita el recálculo. Algunos frameworks (el torch.utils.checkpoint de PyTorch) hacen esto tan simple como envolver una llamada de capa en una función de checkpoint.
Gradient checkpointing se compone con otras optimizaciones de memoria: precisión mixta (FP16/BF16 reduce el tamaño de activación a la mitad), acumulación de gradientes (lotes más pequeños reducen la memoria pico) y FSDP/DeepSpeed (fragmentar parámetros entre GPUs). Juntas, estas pueden reducir la huella de memoria de un modelo de 10–50x comparado con el entrenamiento ingenuo en FP32, permitiendo el entrenamiento de modelos mucho más grandes que la memoria de cualquier GPU individual. Esta pila de optimizaciones es estándar para ajustar modelos de 7B+.