Applied ML

Gradient Checkpointing

Trade recomputation for activation memory

01 · First principles · The memory FSDP cannot touch

Sharding divides model state by N, but there is a second memory consumer it never sees: activations. Backward needs the inputs of every layer to compute that layer's weight gradients, so the default autograd contract is brutal — everything the forward pass produced stays resident until backward consumes it.

Activation memory scales as layers × batch × sequence length × hidden width. For a transformer it is roughly tens of bytes per token per layer for every unit of hidden width, even in bf16; at long sequence lengths or deep stacks the total routinely exceeds the memory taken by model state at 16 bytes/param. Model state is fixed at startup; activations are what actually produce the mid-training OOM when someone doubles the context length.
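To make the scaling concrete, here is a rough back-of-envelope sketch. The constant of 34 bytes per token per unit of hidden width per layer is an assumption (a commonly quoted bf16 estimate that ignores attention-score buffers), and the model shape and shard count are hypothetical, chosen only for illustration.

```python
GIB = 2**30


def activation_bytes(layers, batch, seq_len, hidden, bytes_per_unit=34):
    """Rough stored-activation estimate: layers x batch x seq x hidden x constant.

    bytes_per_unit=34 is an assumed bf16 constant, not a measured value;
    attention-score buffers would add more on top.
    """
    return layers * batch * seq_len * hidden * bytes_per_unit


def model_state_bytes(params, shards=1, bytes_per_param=16):
    """Per-device model state at 16 bytes/param, divided across `shards` devices."""
    return params * bytes_per_param / shards


# Hypothetical setup: 32 layers, hidden width 4096, ~7e9 params, 8-way sharding.
state = model_state_bytes(7e9, shards=8)
act_8k = activation_bytes(layers=32, batch=1, seq_len=8192, hidden=4096)
act_16k = activation_bytes(layers=32, batch=1, seq_len=16384, hidden=4096)

print(f"sharded model state: {state / GIB:6.1f} GiB (fixed at startup)")
print(f"activations @ 8k:    {act_8k / GIB:6.1f} GiB")
print(f"activations @ 16k:   {act_16k / GIB:6.1f} GiB")
```

Under these assumptions sharding shrinks the state term, but the activation term is untouched and doubles when the context length doubles.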

02 · Failure first · The naive alternatives both lose

Store everything (default)
Memory grows as O(L) in depth. At L = 100 layers and 8k context, activations alone can run to hundreds of GB per device. Forward is paid once; memory is the casualty.
Store nothing, recompute from input
O(1) memory, but reaching layer k's activations during backward means rerunning layers 1..k. Summed over all layers that is O(L²) compute — a 100-layer model pays roughly 50 extra forwards.

One end of the spectrum is unaffordable in memory, the other in FLOPs. Checkpointing is the observation that the spectrum has a usable middle.
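The "roughly 50 extra forwards" figure for the store-nothing extreme can be checked by counting layer evaluations directly. This is a small counting sketch; the function name is hypothetical.

```python
def recompute_from_input_cost(num_layers):
    """Layer evaluations needed if backward re-derives every activation from the input.

    Reaching layer k's activations means rerunning layers 1..k, so the total
    is 1 + 2 + ... + L = L * (L + 1) / 2 layer evaluations.
    """
    layer_evals = sum(range(1, num_layers + 1))
    full_forwards = layer_evals / num_layers  # one full forward = L layer evals
    return layer_evals, full_forwards


layer_evals, full_forwards = recompute_from_input_cost(100)
print(layer_evals, full_forwards)
```

For 100 layers this comes to 5050 layer evaluations, or 50.5 full forward passes.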

03 · The mechanism · Checkpoints and segments

  1. Forward: keep activations only at chosen checkpoint layers; everything between two checkpoints is computed and immediately discarded.
  2. Backward, per segment (last to first): rerun the forward from the segment's checkpoint to repopulate its activations, then run backward through the segment and free them again.
  3. Peak activation memory = the checkpoints + one segment's worth of live activations.
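The three steps above can be sketched on a toy chain of scalar "layers". This is an illustration under stated assumptions, not how any framework implements it: each layer is y = sin(w · x), chosen so that backward needs only the layer's input, and all function names are hypothetical.

```python
import math


def layer_forward(w, x):
    # Toy "layer": a scalar map whose backward needs only its input x.
    return math.sin(w * x)


def layer_backward(w, x, grad_out):
    # y = sin(w * x)  =>  dy/dx = w * cos(w * x),  dy/dw = x * cos(w * x)
    c = math.cos(w * x)
    return grad_out * w * c, grad_out * x * c


def grads_store_everything(weights, x0):
    inputs, x = [], x0
    for w in weights:
        inputs.append(x)              # every layer input stays resident
        x = layer_forward(w, x)
    grad, grad_w = 1.0, [0.0] * len(weights)
    for i in reversed(range(len(weights))):
        grad, grad_w[i] = layer_backward(weights[i], inputs[i], grad)
    return grad_w, len(inputs)        # gradients, peak stored activations


def grads_checkpointed(weights, x0, s):
    n = len(weights)
    # Step 1: forward keeps only the input of every s-th layer.
    checkpoints, x, forward_evals = {}, x0, 0
    for i, w in enumerate(weights):
        if i % s == 0:
            checkpoints[i] = x
        x = layer_forward(w, x)
        forward_evals += 1
    grad, grad_w, peak = 1.0, [0.0] * n, len(checkpoints)
    # Step 2: walk the segments last to first, rerunning each from its checkpoint.
    for start in sorted(checkpoints, reverse=True):
        end = min(start + s, n)
        live = [checkpoints[start]]
        for i in range(start, end):
            live.append(layer_forward(weights[i], live[-1]))
            forward_evals += 1
        # Step 3: peak = checkpoints + one segment (live[0] is the checkpoint itself).
        peak = max(peak, len(checkpoints) + len(live) - 1)
        for i in reversed(range(start, end)):
            grad, grad_w[i] = layer_backward(weights[i], live[i - start], grad)
        del live                      # free the segment before the next one
    return grad_w, peak, forward_evals


weights = [0.5 + 0.05 * i for i in range(16)]
full, full_peak = grads_store_everything(weights, 0.3)
ckpt, ckpt_peak, evals = grads_checkpointed(weights, 0.3, s=4)
print(full_peak, ckpt_peak, evals)
```

With 16 layers and s = 4, the toy holds 8 values at peak instead of 16 and performs 32 layer evaluations instead of 16, while producing the same gradients.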

With L layers split into segments of length s, memory is roughly L/s checkpoints plus s live activations. Minimising L/s + s gives s = √L:

memory  ∝  L/s + s   ⟶   min at s = √L   ⟹   O(√L) memory, one extra forward
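The minimum can be checked with one line of calculus, treating s as continuous (an idealisation, since real segment lengths are integers) and writing M(s) for the memory in units of one layer's activations:

```latex
\begin{aligned}
M(s) &= \frac{L}{s} + s, \\
\frac{dM}{ds} &= -\frac{L}{s^{2}} + 1 = 0 \;\Longrightarrow\; s^{*} = \sqrt{L}, \\
\frac{d^{2}M}{ds^{2}} &= \frac{2L}{s^{3}} > 0 \quad \text{(so this is a minimum)}, \\
M(s^{*}) &= \frac{L}{\sqrt{L}} + \sqrt{L} = 2\sqrt{L} = O(\sqrt{L}).
\end{aligned}
```

For L = 16 this gives s* = 4 and M = 8 units, against 16 units for storing everything.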

The compute overhead is exactly one additional forward pass regardless of s, because each layer is forward-computed twice in total. A forward is roughly one third of a training step's FLOPs (backward costs about two forwards), so checkpointing everything costs about 33% more compute — in practice often 20–30% more wall time, since recompute overlaps other work imperfectly.

[Figure: L = 16 layers, checkpoint every √L = 4. Forward: ■ stored (checkpoint), □ computed, discarded. Backward: earlier segments wait; recompute a segment from its checkpoint, then backprop through it, then free.]

Only √L checkpoints survive forward. Backward revives one segment at a time; peak memory is checkpoints + one live segment.

04 · Practice · What actually gets checkpointed

Nobody solves the optimisation per-model; the working convention is one checkpoint per transformer block (torch.utils.checkpoint.checkpoint around each block, or activation_checkpointing policies in FSDP). Two practical wrinkles deserve attention:

05 · The judgment call · When the trade is right

Situation | Checkpoint? | Why
OOM, and batch size cannot shrink further | Yes | ~30% slower beats not training at all
Memory freed lets batch size grow enough to raise GPU utilisation | Often yes | Bigger batches can repay the recompute with interest
Long-context training | Usually yes | Activations scale with sequence length; state does not
Compute-bound and memory comfortable | No | You would pay about 33% more FLOPs for memory you do not need

Order of operations: checkpointing trades the cheap resource (FLOPs are growing faster than HBM capacity) for the scarce one. It composes with everything: FSDP shards state, checkpointing caps activations, accumulation shrinks the live micro-batch. Large-model recipes typically use all three.