Zubnet AIAprenderWiki › Gradient Checkpointing
Entrenamiento

Gradient Checkpointing

También conocido como: Activation Checkpointing, Rematerialización
Una técnica de ahorro de memoria que intercambia cómputo por memoria durante el entrenamiento. En lugar de almacenar todas las activaciones intermedias del pase forward (necesarias para la retropropagación), gradient checkpointing solo almacena activaciones en ciertas capas "checkpoint" y recalcula las demás durante el pase backward. Esto reduce el uso de memoria hasta 5–10x a cambio de ~30% más de cómputo.

Por qué importa

Gradient checkpointing es lo que hace posible ajustar modelos grandes con memoria GPU limitada. Sin él, un modelo de 7B podría necesitar más de 80 GB solo para activaciones durante el entrenamiento, excediendo la capacidad de una sola GPU. Con gradient checkpointing, el mismo modelo se puede ajustar en una GPU de consumo de 24GB. Es la optimización de memoria más comúnmente usada para entrenamiento.

En profundidad

Durante el pase forward, las activaciones de entrada de cada capa son necesarias durante el pase backward para calcular gradientes. Normalmente, todas las activaciones se almacenan en memoria. Con gradient checkpointing, solo se almacenan las activaciones de ciertas capas. Durante el pase backward, cuando se necesita una activación no almacenada, se vuelve a ejecutar el pase forward desde el checkpoint más cercano para recalcularla. Esto intercambia ~30% de cómputo extra (recalcular activaciones) por ~5x de ahorro de memoria (no almacenarlas todas).

Ubicación de los checkpoints

La ubicación óptima de los checkpoints depende de la arquitectura del modelo. El enfoque más simple: hacer checkpoint cada N capas (por ejemplo, cada tercer bloque Transformer). Más sofisticado: analizar los tamaños de activación por capa y colocar checkpoints para minimizar la memoria total mientras se limita el recálculo. Algunos frameworks (el torch.utils.checkpoint de PyTorch) hacen esto tan simple como envolver una llamada de capa en una función de checkpoint.

Combinación con otras técnicas

Gradient checkpointing se compone con otras optimizaciones de memoria: precisión mixta (FP16/BF16 reduce el tamaño de activación a la mitad), acumulación de gradientes (lotes más pequeños reducen la memoria pico) y FSDP/DeepSpeed (fragmentar parámetros entre GPUs). Juntas, estas pueden reducir la huella de memoria de un modelo de 10–50x comparado con el entrenamiento ingenuo en FP32, permitiendo el entrenamiento de modelos mucho más grandes que la memoria de cualquier GPU individual. Esta pila de optimizaciones es estándar para ajustar modelos de 7B+.

Conceptos relacionados

← Todos los términos