Durante o forward pass, as ativações de entrada de cada camada são necessárias durante o backward pass para computar gradientes. Normalmente, todas as ativações são armazenadas em memória. Com gradient checkpointing, apenas ativações de certas camadas são armazenadas. Durante o backward pass, quando uma ativação não armazenada é necessária, o forward pass é re-executado a partir do checkpoint mais próximo para recalculá-la. Isso troca ~30% de computação extra (recalculando ativações) por ~5x de economia de memória (não armazenando todas).
O posicionamento ideal de checkpoints depende da arquitetura do modelo. A abordagem mais simples: fazer checkpoint a cada N camadas (ex.: a cada 3 blocos Transformer). Mais sofisticado: analisar os tamanhos de ativação por camada e posicionar checkpoints para minimizar memória total limitando a recomputação. Alguns frameworks (torch.utils.checkpoint do PyTorch) tornam isso tão simples quanto envolver uma chamada de camada em uma função de checkpoint.
Gradient checkpointing se compõe com outras otimizações de memória: precisão mista (FP16/BF16 reduz pela metade o tamanho das ativações), acumulação de gradiente (lotes menores reduzem pico de memória) e FSDP/DeepSpeed (fragmenta parâmetros entre GPUs). Juntas, essas técnicas podem reduzir o footprint de memória de um modelo em 10–50x comparado a treinamento ingênuo em FP32, permitindo treinar modelos muito maiores que a memória de qualquer GPU individual. Essa pilha de otimizações é padrão para fine-tuning de modelos 7B+.