
JIT Compilation

Stop interpreting the model one op at a time

01 · First principles
What eager mode actually pays

Eager PyTorch is an interpreter: every x + y travels through Python dispatch, type and device resolution, autograd bookkeeping, and a CUDA kernel launch — often 10–30 µs of overhead for a kernel that may run for 5 µs. A transformer forward issues thousands of such ops. While each kernel is large, the overhead hides behind asynchronous execution; shrink the kernels (small batch, inference, the long tail of pointwise ops) and the GPU starts idling between launches, visible as gaps in any trace.
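A back-of-envelope model makes the idle-GPU point concrete. It is a sketch under simplifying assumptions: every op costs the CPU a fixed `overhead_us` and the GPU a fixed `kernel_us`, and launches are asynchronous, so the slower of the two sets the pace. Both parameters are illustrative, not measurements.

```python
def gpu_busy_fraction(overhead_us: float, kernel_us: float) -> float:
    """Fraction of wall-clock time the GPU spends computing in steady state."""
    # Async launches overlap CPU dispatch with GPU work, so each op effectively
    # costs whichever side is slower.
    return kernel_us / max(overhead_us, kernel_us)


def stream_time_us(n_ops: int, overhead_us: float, kernel_us: float) -> float:
    """Approximate wall-clock for a stream of n_ops identical ops."""
    # The first op pays both costs back to back; after that the pipeline is
    # full and each further op adds max(overhead, kernel).
    return overhead_us + kernel_us + (n_ops - 1) * max(overhead_us, kernel_us)


# Numbers in the range the paragraph quotes: 20 µs of overhead, 5 µs kernels.
print(gpu_busy_fraction(20.0, 5.0))      # 0.25: GPU idle three quarters of the time
print(gpu_busy_fraction(20.0, 200.0))    # 1.0: long kernels hide the overhead
print(stream_time_us(1000, 20.0, 5.0))   # 20005.0
```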

The deeper cost is structural: an interpreter sees one op at a time, so it cannot fuse. A chain like bias-add → GeLU → dropout → residual writes its tensor to HBM and reads it back at every arrow. Those ops are bandwidth-bound (arithmetic intensity below 1 — see the roofline), so the redundant round-trips are the whole cost, and no amount of faster individual kernels removes them. Removing them requires seeing the chain, and seeing the chain requires a graph.
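The cost of those round-trips can be counted directly. Below is a pure-Python sketch (not GPU code) of the bias-add → GeLU → dropout → residual chain, once as four separate passes and once fused into a single loop. `Memory` is a hypothetical counter standing in for HBM; it tracks only the activation stream, since the side inputs (bias, mask, residual) are read the same number of times either way. Dropout uses a fixed keep-mask so both versions compute identical results.

```python
import math


class Memory:
    """Stand-in for HBM that counts element reads and writes of the activation stream."""

    def __init__(self):
        self.reads = 0
        self.writes = 0

    def load(self, buf, i):
        self.reads += 1
        return buf[i]

    def store(self, buf, i, value):
        self.writes += 1
        buf[i] = value


def gelu(v):
    return 0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0)))  # exact, erf-based GeLU


def eager_chain(mem, x, bias, keep, residual, p):
    """One 'kernel' per op: each writes a full intermediate that the next op reads back."""
    scale = 1.0 / (1.0 - p)
    ops = [
        lambda v, i: v + bias[i],          # bias-add
        lambda v, i: gelu(v),              # GeLU
        lambda v, i: v * keep[i] * scale,  # dropout with a precomputed keep-mask
        lambda v, i: v + residual[i],      # residual add
    ]
    cur = x
    for op in ops:
        out = [0.0] * len(x)
        for i in range(len(x)):
            mem.store(out, i, op(mem.load(cur, i), i))
        cur = out
    return cur


def fused_chain(mem, x, bias, keep, residual, p):
    """The same chain as one 'kernel': intermediates live in a local ('register')."""
    scale = 1.0 / (1.0 - p)
    out = [0.0] * len(x)
    for i in range(len(x)):
        v = mem.load(x, i)      # the only read of the activation
        v = v + bias[i]
        v = gelu(v)
        v = v * keep[i] * scale
        v = v + residual[i]
        mem.store(out, i, v)    # the only write
    return out


x = [0.1 * i - 0.4 for i in range(8)]
bias, residual = [0.05] * 8, [1.0] * 8
keep = [1, 0, 1, 1, 0, 1, 1, 1]
eager_mem, fused_mem = Memory(), Memory()
y_eager = eager_chain(eager_mem, x, bias, keep, residual, p=0.25)
y_fused = fused_chain(fused_mem, x, bias, keep, residual, p=0.25)
print(eager_mem.reads, eager_mem.writes)  # 32 32: 4 ops x 8 elements, each way
print(fused_mem.reads, fused_mem.writes)  # 8 8
```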

02 · The mechanism
Capture, specialise, fuse

  1. Capture: turn the Python function into a graph of tensor ops — by tracing actual execution, by analysing bytecode, or by having the user write graph-friendly code in the first place.
  2. Specialise: fix what eager left dynamic — dtypes, devices, and usually shapes. A matmul known to be (4096 × 4096) @ (4096 × 11008) in bf16 can get a tile configuration chosen for exactly that problem.
  3. Fuse and lower: merge pointwise chains (and increasingly reductions) into single kernels that keep intermediates in registers; emit code; cache it keyed on the specialisation.
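The three steps can be sketched in a few dozen lines of plain Python, under heavy simplification: `Proxy` captures by tracing (operator overloading records each `+` and `*`), and `lower` specialises on the input length and fuses the whole chain into one generated loop. `Proxy`, `capture`, and `lower` are hypothetical names for illustration; real compilers handle arbitrary tensor ops, layouts, and hardware, and caching is omitted here.

```python
class Proxy:
    """Stands in for one element during tracing; records ops instead of computing them."""

    def __init__(self, graph, name):
        self.graph = graph  # shared list of (dst, op, lhs, rhs) records
        self.name = name

    def _record(self, op, other):
        dst = f"t{len(self.graph)}"
        rhs = other.name if isinstance(other, Proxy) else repr(other)
        self.graph.append((dst, op, self.name, rhs))
        return Proxy(self.graph, dst)

    def __add__(self, other):
        return self._record("+", other)

    def __mul__(self, other):
        return self._record("*", other)

    __radd__ = __add__  # + and * are commutative, so 2.0 * x can reuse __mul__
    __rmul__ = __mul__


def capture(fn):
    """Step 1: run fn once on a Proxy to record its op graph."""
    graph = []
    out = fn(Proxy(graph, "xi"))
    return graph, out.name


def lower(graph, out_name, n):
    """Steps 2-3: emit ONE loop for the whole chain, specialised to length n.

    Intermediates t0, t1, ... are locals in the loop body, so each element is
    read once and written once; no intermediate list is ever materialised.
    """
    lines = [
        "def kernel(x):",
        f"    out = [0.0] * {n}",
        f"    for i in range({n}):",
        "        xi = x[i]",
    ]
    lines += [f"        {dst} = {lhs} {op} {rhs}" for dst, op, lhs, rhs in graph]
    lines += [f"        out[i] = {out_name}", "    return out"]
    namespace = {}
    exec("\n".join(lines), namespace)
    return namespace["kernel"]


def f(x):
    return (x + 1.0) * 3.0 + x


graph, out_name = capture(f)         # [('t0','+','xi','1.0'), ('t1','*','t0','3.0'), ('t2','+','t1','xi')]
kernel = lower(graph, out_name, n=3)
print(kernel([1.0, 2.0, 3.0]))       # [7.0, 11.0, 15.0]
```

Because the length is baked into the generated code, calling `kernel` on a list of another length would be wrong (a shorter list raises, a longer one is silently truncated), which is the reason compiled code has to carry guards.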
eager:  k ops → k launches + k HBM reads + k HBM writes   ⟹   compiled:  1 launch + 1 read + 1 write
for a pointwise chain: the fusion win in one line
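Plugging sizes into that line shows the scale of the win. A minimal sketch, assuming every op in the chain reads one full-size tensor and writes one (side inputs such as a bias vector are ignored); the activation shape and the 2 TB/s bandwidth are illustrative assumptions, not figures for a particular GPU.

```python
def pointwise_traffic_bytes(n_elems: int, bytes_per_elem: int, k_ops: int, fused: bool) -> int:
    """HBM bytes moved by a chain of k_ops pointwise ops over one tensor."""
    # Eager: every op reads its input and writes its output -> 2k passes.
    # Fused: one read of the input and one write of the final result -> 2 passes.
    passes = 2 if fused else 2 * k_ops
    return passes * n_elems * bytes_per_elem


n = 8192 * 4096        # hypothetical activation: 8192 tokens x 4096 hidden
bf16 = 2               # bytes per element
eager = pointwise_traffic_bytes(n, bf16, k_ops=4, fused=False)
fused = pointwise_traffic_bytes(n, bf16, k_ops=4, fused=True)
bandwidth = 2e12       # assumed HBM bandwidth in bytes/s

print(eager // 2**20, fused // 2**20)                                   # 512 128 (MiB)
print(round(eager / bandwidth * 1e6), round(fused / bandwidth * 1e6))   # 268 67 (µs at full bandwidth)
```

The ratio is simply k; for bandwidth-bound ops it carries over directly into time.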

The speedup is mostly fusion plus launch-overhead removal, not "better matmuls" — the big matmuls were already going to cuBLAS-grade kernels. Typical training speedups from torch.compile land around 1.2–1.8×, larger when the model is small-op-heavy, smaller when it is one giant matmul.
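One way to see why the gain depends on the op mix is an Amdahl's-law estimate: only the share of step time spent in fusible pointwise work and launch gaps gets faster, while the matmul share is unchanged. The function is illustrative and the fractions below are assumptions, not profiles of real models.

```python
def overall_speedup(fusible_fraction: float, fusion_speedup: float) -> float:
    """Amdahl-style estimate of end-to-end speedup from compilation.

    fusible_fraction: share of eager step time in pointwise ops and launch gaps.
    fusion_speedup:   how much faster that share runs once fused.
    The remaining (matmul) share is assumed to run at the same speed as before.
    """
    return 1.0 / ((1.0 - fusible_fraction) + fusible_fraction / fusion_speedup)


print(round(overall_speedup(0.40, 4.0), 2))  # 1.43: small-op-heavy model
print(round(overall_speedup(0.05, 4.0), 2))  # 1.04: dominated by one big matmul
print(round(1.0 / (1.0 - 0.40), 2))          # 1.67: ceiling even with infinitely fast fusion
```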

03 · The implementation
torch.compile and XLA

torch.compile (Dynamo + Inductor). Dynamo intercepts Python bytecode and extracts graphs while letting genuinely dynamic Python fall back to eager — when it hits something it cannot capture (a data-dependent branch, an unsupported call), it inserts a graph break and stitches eager execution around it, so almost any model runs, just with fewer fusion opportunities per break. Each captured graph carries guards — recorded assumptions like "x.shape[0] == 32" — checked on every call; a failed guard means recompilation. Inductor then lowers the graph, generating Triton kernels for the fused regions. The design bet: never demand that users rewrite their model, accept partial graphs as the price.
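Guards can be pictured as predicates stored next to each compiled artifact. The sketch below is a toy in the spirit of that design, not Dynamo's real data structures: the hypothetical `GuardedFunction` records a length guard (standing in for something like `x.shape[0] == 32`) and an element-type guard (standing in for a dtype check) from the first input, reuses the cached code while they hold, and recompiles when they fail. The `recompile_limit`, after which it stops compiling and runs eager, is a hypothetical parameter, and "compilation" itself is stubbed out.

```python
class GuardedFunction:
    """Toy compile cache: each entry pairs guard predicates with the code built under them."""

    def __init__(self, fn, recompile_limit=8):
        self.fn = fn
        self.recompile_limit = recompile_limit  # hypothetical cap on compiles
        self.entries = []                       # list of (guards, code)
        self.compiles = 0
        self.eager_calls = 0

    def _compile(self, x):
        # Stand-in for capture + codegen; a real compiler would specialise on x here.
        return self.fn

    def __call__(self, x):
        for guards, code in self.entries:
            if all(guard(x) for guard in guards):  # guards are checked on every call
                return code(x)
        if self.compiles >= self.recompile_limit:
            self.eager_calls += 1
            return self.fn(x)                      # stop recompiling; fall back to eager
        n, elem_type = len(x), type(x[0])
        guards = [
            lambda v, n=n: len(v) == n,             # shape-like guard
            lambda v, t=elem_type: type(v[0]) is t,  # dtype-like guard
        ]
        code = self._compile(x)
        self.entries.append((guards, code))
        self.compiles += 1
        return code(x)


double = GuardedFunction(lambda v: [a * 2 for a in v], recompile_limit=2)
double([1.0, 2.0])        # compile #1; guards: len == 2, elements are float
double([3.0, 4.0])        # guards pass: cached code reused
double([1.0, 2.0, 3.0])   # length guard fails: compile #2
double([1, 2])            # type guard fails and the limit is reached: runs eager
print(double.compiles, double.eager_calls)  # 2 1
```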

XLA. The opposite bet. XLA consumes whole programs in a functional IR (HLO) and optimises globally — aggressive fusion, layout assignment, buffer reuse — and it is the only path to TPUs. JAX is designed around it: jit-traced functions must be pure and shape-static, and in exchange XLA gets a complete, break-free graph every time. Whole-program views also let it schedule communication, overlapping collectives with compute in ways per-op execution cannot.
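The "shape-static" requirement is easiest to see from the tracer's side: during tracing, shapes are known but values are not. The toy below loosely mimics that behaviour; `ShapeOnlyTracer` is a hypothetical class, not JAX's actual tracer, and it models only the one property that matters here.

```python
class ShapeOnlyTracer:
    """Stands in for an array during tracing: the shape is known, the values are not."""

    def __init__(self, shape):
        self.shape = shape  # static: fixed at trace time

    def __gt__(self, other):
        # Comparing values is fine as a traced op; it yields another abstract array.
        return ShapeOnlyTracer(self.shape)

    def __bool__(self):
        # Python `if` needs a concrete True/False, which tracing cannot supply.
        raise TypeError("data-dependent branch: value is unknown at trace time")


def branch_on_shape(x):
    # Allowed: the shape is a plain Python tuple during tracing.
    return "batched path" if x.shape[0] > 1 else "single path"


def branch_on_data(x):
    # Not allowed: `x > 0` is abstract, so the conditional cannot decide.
    return "positive path" if x > 0 else "other path"


t = ShapeOnlyTracer((4, 16))
print(branch_on_shape(t))  # batched path
try:
    branch_on_data(t)
except TypeError as err:
    print("trace failed:", err)
```

Real JAX draws the same line: branching on `x.shape` inside a `jax.jit` function is fine, while branching on values has to be expressed with `jnp.where` or `jax.lax.cond` so it becomes part of the graph.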

04 · The bill
What compilation costs

Compile latency
  Mechanism: The first call pays seconds to minutes (capture, autotuning, codegen); painful in notebooks and short jobs.
  Mitigation: Persistent caches; amortise over long runs. A 1.3× speedup repays a few minutes of compile in well under an hour of training.
Recompilation on shape change
  Mechanism: Guards specialise on shapes, so a new sequence length or batch size can trigger a fresh compile; varied-length workloads can recompile repeatedly until Dynamo's recompile limit is reached and the function falls back to eager.
  Mitigation: Pad to buckets of fixed shapes; dynamic=True (symbolic shapes: fewer recompiles, weaker codegen).
Debugging opacity
  Mechanism: Stack traces point into generated kernels; print and breakpoints inside compiled regions either break capture or lie; numerics differ slightly from eager (fused ops round differently, see non-associativity).
  Mitigation: Develop eager, compile late; bisect with graph breaks; use TORCH_LOGS to see what Dynamo did.
Capture restrictions
  Mechanism: Data-dependent control flow, side effects, and exotic Python either break the graph (Dynamo) or are forbidden outright (XLA).
  Mitigation: Keep hot paths tensor-pure; move branching outside the compiled region.
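The compile-latency trade is simple arithmetic. A minimal sketch, assuming a one-off compile cost and a constant speedup afterwards (recompiles and warm caches ignored); the inputs are hypothetical.

```python
def break_even_minutes(compile_minutes: float, speedup: float) -> float:
    """Eager-equivalent runtime after which a one-off compile has paid for itself."""
    if speedup <= 1.0:
        raise ValueError("without a speedup the compile never pays back")
    # Compiled time = eager time / speedup, so each eager minute saves
    # (1 - 1/speedup) minutes; divide the compile cost by that saving.
    return compile_minutes / (1.0 - 1.0 / speedup)


print(round(break_even_minutes(5.0, 1.3), 1))   # 21.7
print(round(break_even_minutes(10.0, 1.3), 1))  # 43.3
```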
The recurring trap: a model that recompiles every step is slower than eager and harder to debug — the worst point on the curve. Watch the recompile counters (torch._dynamo.utils.counters) before celebrating a "successful" compile; one un-padded dynamic dimension is usually the culprit.
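Bucketing, the usual fix for an un-padded dynamic dimension, can be sketched in a few lines: round each sequence length up to the nearest of a small fixed set, so a shape-specialised function only ever sees those shapes. The bucket edges and `pad_id` are hypothetical choices, and the model must also mask the padded positions (attention and loss), which is not shown.

```python
import bisect

BUCKETS = (128, 256, 512, 1024, 2048)  # hypothetical bucket edges


def bucket_length(n: int, buckets=BUCKETS) -> int:
    """Smallest bucket that fits a sequence of length n."""
    i = bisect.bisect_left(buckets, n)
    if i == len(buckets):
        raise ValueError(f"length {n} exceeds the largest bucket {buckets[-1]}")
    return buckets[i]


def pad_to_bucket(tokens, pad_id=0, buckets=BUCKETS):
    """Right-pad a token list to its bucket; also return the true length for masking."""
    target = bucket_length(len(tokens), buckets)
    return tokens + [pad_id] * (target - len(tokens)), len(tokens)


# Lengths 1..2048 collapse to 5 shapes instead of one shape per length.
print(len({bucket_length(n) for n in range(1, 2049)}))  # 5
padded, true_len = pad_to_bucket([5, 6, 7])
print(len(padded), true_len)  # 128 3
```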

05 · Judgment
When to reach for it

Compile
Long training runs; static or bucketable shapes; inference serving at fixed shapes; models thick with pointwise ops and norms. Free lunch territory once the cache is warm.
Stay eager
Active model development and debugging; highly dynamic shapes or control flow; short experiments where compile time dominates; anywhere a confusing failure costs more than 1.3× speed.

The honest framing: JIT compilation trades flexibility you were not using (arbitrary Python between ops, fully dynamic shapes) for bandwidth and overhead wins. When you genuinely use that flexibility, the trade reverses — which is precisely the design split explored in JAX vs PyTorch vs TensorFlow.

Mental Model