Stop interpreting the model one op at a time
Eager PyTorch is an interpreter: every x + y travels through Python dispatch, type and device resolution, autograd bookkeeping, and a CUDA kernel launch. That is often 10–30 µs of overhead for a kernel that may run for only 5 µs, and a transformer forward pass issues thousands of such ops. As long as each kernel is large, the overhead hides behind asynchronous execution. Shrink the kernels (small batches, inference, the long tail of pointwise ops) and the GPU starts idling between launches, visible as gaps in any profiler trace.
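A back-of-envelope model makes those idle gaps concrete. This is a toy, not a profiler. It assumes CPU dispatch and GPU execution overlap perfectly, so each op costs max(kernel time, launch overhead) of wall time, and that every op is identical. The 20 µs and 5 µs figures echo the ranges above, and the 3,000-op count is hypothetical.

```python
# Toy model of an eager op stream with asynchronous launches: the CPU spends
# overhead_us dispatching each op, the GPU spends kernel_us running it, and the
# two overlap. Assumes identical ops and perfect overlap (a simplification).


def gpu_utilization(kernel_us: float, overhead_us: float) -> float:
    """Fraction of wall time the GPU is busy in steady state.

    Each op takes max(kernel, overhead) of wall time: when dispatch is slower
    than the kernel, the GPU waits for the next launch.
    """
    step_us = max(kernel_us, overhead_us)
    return kernel_us / step_us


def stream_time_ms(n_ops: int, kernel_us: float, overhead_us: float) -> float:
    """Approximate wall time for n_ops identical ops in steady state."""
    return n_ops * max(kernel_us, overhead_us) / 1000.0


if __name__ == "__main__":
    # Large kernels: dispatch overhead hides behind execution.
    print(gpu_utilization(kernel_us=200, overhead_us=20))
    # Small kernels: the GPU sits idle between launches.
    print(gpu_utilization(kernel_us=5, overhead_us=20))
    # A hypothetical 3,000 small ops per forward pass, bounded by dispatch.
    print(stream_time_ms(3000, kernel_us=5, overhead_us=20))
```

With 5 µs kernels behind 20 µs of dispatch, the GPU is busy only a quarter of the time. Making the kernels faster changes nothing here; only removing launches does.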
The deeper cost is structural: an interpreter sees one op at a time, so it cannot fuse. A chain like bias-add → GELU → dropout → residual writes its tensor to HBM and reads it back at every arrow. Those ops are bandwidth-bound (arithmetic intensity below 1 FLOP/byte; see the roofline), so the redundant round-trips are nearly the whole cost, and no amount of faster individual kernels removes them. Removing them requires seeing the chain, and seeing the chain requires a graph.
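Counting bytes shows why only fusion helps. The sketch below tallies full-size tensor reads and writes for the chain above, under these simplifying assumptions:

- The bias vector is negligible.
- Dropout's mask is generated in-kernel rather than read, and is not saved for backward as training would do.
- Every intermediate is the same size as the input.

The activation shape is hypothetical.

```python
# HBM traffic for x -> bias-add -> GELU -> dropout -> (+ residual), counted in
# full-size tensor passes. Assumptions (not measurements): bias vector is
# negligible, dropout mask is generated in-kernel and not stored, and every
# intermediate has the same number of elements as x.

UNFUSED_CHAIN = [
    # (op, full-size tensors read, full-size tensors written)
    ("bias_add", 1, 1),  # read x, write x + b
    ("gelu", 1, 1),      # read, write
    ("dropout", 1, 1),   # read, write
    ("residual", 2, 1),  # read h and the residual input, write the sum
]


def unfused_traffic(n_elems: int, bytes_per_elem: int) -> int:
    """Bytes moved when every op round-trips its tensors through HBM."""
    return sum((r + w) * n_elems * bytes_per_elem for _, r, w in UNFUSED_CHAIN)


def fused_traffic(n_elems: int, bytes_per_elem: int) -> int:
    """Bytes moved by one fused kernel: read x and residual once, write once."""
    return 3 * n_elems * bytes_per_elem


if __name__ == "__main__":
    n = 8 * 2048 * 4096  # hypothetical activation: batch 8, seq 2048, hidden 4096
    b = 2                # bf16
    u, f = unfused_traffic(n, b), fused_traffic(n, b)
    print(f"unfused {u / 2**20:.0f} MiB, fused {f / 2**20:.0f} MiB, ratio {u / f:.1f}x")
```

Under these assumptions fusion cuts traffic from nine tensor passes to three. Because a bandwidth-bound op's time is roughly bytes moved divided by HBM bandwidth, that is close to a 3× cut in time for this chain.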
The speedup comes mostly from fusion plus launch-overhead removal, not from "better matmuls": the big matmuls were already running on cuBLAS-grade kernels. Typical training speedups from torch.compile land around 1.2–1.8×. They are larger when the model is small-op-heavy and smaller when it is dominated by one giant matmul.
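Amdahl's law explains that spread. This is a minimal sketch that assumes step time splits cleanly into two shares:

- a fusable share (pointwise ops plus launch overhead) that compilation speeds up by some factor;
- a matmul share that it leaves alone.

The fractions and the 3× factor are hypothetical, not measurements.

```python
def overall_speedup(fused_fraction: float, fused_speedup: float) -> float:
    """Amdahl's law: only fused_fraction of eager step time gets faster.

    fused_fraction: share of eager time in pointwise ops and launch overhead
        (the part a compiler can fuse or eliminate).
    fused_speedup: how much faster that share runs after compilation.
    The remaining share (large matmuls on cuBLAS-grade kernels) is unchanged.
    """
    if not 0.0 <= fused_fraction <= 1.0:
        raise ValueError("fused_fraction must be in [0, 1]")
    remaining = (1.0 - fused_fraction) + fused_fraction / fused_speedup
    return 1.0 / remaining


if __name__ == "__main__":
    # Hypothetical profiles, all with the same 3x gain on the fusable share.
    profiles = [("small-op-heavy", 0.6), ("balanced", 0.35), ("one giant matmul", 0.05)]
    for name, frac in profiles:
        print(f"{name:>16}: {overall_speedup(frac, 3.0):.2f}x")
```

With these made-up profiles the results come to about 1.67×, 1.30× and 1.03×. The same compiler gain on the fusable share yields very different end-to-end numbers depending on how much of the step it touches.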
torch.compile (Dynamo + Inductor). Dynamo hooks CPython's frame evaluation (PEP 523) and reads each function's bytecode. It extracts graphs while letting genuinely dynamic Python fall back to eager. When it hits something it cannot capture (a data-dependent branch, an unsupported call), it inserts a graph break and stitches eager execution around it. Almost any model therefore runs, just with fewer fusion opportunities per break. Each captured graph carries guards, recorded assumptions like "x.shape[0] == 32" that are checked on every call; when no cached graph's guards pass, Dynamo recompiles. Inductor then lowers the graph, generating Triton kernels for the fused regions on GPU (C++/OpenMP code on CPU). The design bet: never demand that users rewrite their model, and accept partial graphs as the price.
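The guard-and-recompile loop can be modelled in a few lines. This is a toy, not Dynamo's implementation: real guards also cover dtypes, devices, strides, globals and more. Here a "compiled graph" is just the original function and a guard is an exact-shape check, which corresponds to compiling with dynamic shapes disabled. GuardedCache and forward are hypothetical names.

```python
from typing import Callable, List, Tuple

Shape = Tuple[int, ...]
Guard = Callable[[Shape], bool]
Compiled = Callable[[Shape], str]


class GuardedCache:
    """Toy cache of (guard, compiled graph) entries for one function."""

    def __init__(self, fn: Compiled):
        self.fn = fn
        self.entries: List[Tuple[Guard, Compiled]] = []
        self.compiles = 0

    def _compile(self, shape: Shape) -> Tuple[Guard, Compiled]:
        # Specialise on the exact shape seen, as static-shape capture does.
        self.compiles += 1
        expected = shape

        def guard(s: Shape) -> bool:
            return s == expected

        return guard, self.fn  # stand-in for generated kernels

    def __call__(self, shape: Shape) -> str:
        # Every call checks guards; reuse the first entry whose guard passes.
        for guard, compiled in self.entries:
            if guard(shape):
                return compiled(shape)
        # All guards failed: "recompile" and cache the new entry.
        guard, compiled = self._compile(shape)
        self.entries.append((guard, compiled))
        return compiled(shape)


def forward(shape: Shape) -> str:
    """Hypothetical model body; only the input shape matters here."""
    return f"ran on {shape}"


if __name__ == "__main__":
    f = GuardedCache(forward)
    for seq_len in [128, 128, 256, 128, 512]:
        f((32, seq_len))
    print(f.compiles)  # 3: one compile per distinct shape
```

Each distinct sequence length costs one compile here. A workload with thousands of distinct lengths would compile thousands of times, which is the failure mode that shape-change mitigations target.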
XLA. The opposite bet. XLA consumes whole programs in a functional IR (HLO) and optimises globally (aggressive fusion, layout assignment, buffer reuse), and it is the only path to TPUs. JAX is designed around it. jit-traced functions must be pure, with shapes fixed at trace time: a new input shape means a retrace and recompile. In exchange, XLA gets a complete, break-free graph every time. The whole-program view also lets it schedule communication, overlapping collectives with compute in ways per-op execution cannot.
| Cost | Mechanism | Mitigation |
|---|---|---|
| Compile latency | First call pays seconds to minutes (capture, autotuning, codegen); painful in notebooks and short jobs | Persistent caches; amortise over long runs — a 1.3× speedup repays minutes of compile within hours |
| Recompilation on shape change | Guards specialise on shapes, so a new sequence length or batch size fails them and triggers a fresh compile. By default Dynamo recompiles once more with that dimension made symbolic, and a function that exceeds the recompile limit falls back to eager | Pad to buckets of fixed shapes; dynamic=True (symbolic shapes: fewer recompiles, weaker codegen) |
| Debugging opacity | Stack traces point into generated kernels; print and breakpoints inside compiled regions break capture or lie; numerics differ slightly from eager (fused ops round differently; see non-associativity) | Develop eager, compile late; bisect with graph breaks; TORCH_LOGS to see what Dynamo did |
| Capture restrictions | Data-dependent control flow, side effects, and exotic Python either break the graph (Dynamo) or are simply forbidden (XLA) | Keep hot paths tensor-pure; move branching outside the compiled region |
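
The first two rows of the table reduce to simple arithmetic. This is a minimal sketch that assumes right-padded token lists and a hypothetical set of bucket boundaries:

- pad_to_bucket rounds each sequence up to the next bucket, so only a handful of distinct lengths ever reach the compiler.
- break_even_eager_seconds solves for how much eager-equivalent work a one-off compile needs in order to repay itself.

Neither helper is a PyTorch API; both names are illustrative.

```python
import bisect
import math

# Hypothetical bucket boundaries: choose them from your own length histogram.
BUCKETS = [128, 256, 512, 1024, 2048]


def bucket_length(seq_len: int, buckets: list = BUCKETS) -> int:
    """Round seq_len up to the smallest bucket that can hold it."""
    i = bisect.bisect_left(buckets, seq_len)
    if i == len(buckets):
        raise ValueError(f"length {seq_len} exceeds the largest bucket {buckets[-1]}")
    return buckets[i]


def pad_to_bucket(tokens: list, pad_id: int = 0) -> list:
    """Right-pad tokens to a bucket length; pad_id=0 is an assumed pad token.

    The padded positions must still be masked out of attention and the loss.
    """
    target = bucket_length(len(tokens))
    return tokens + [pad_id] * (target - len(tokens))


def break_even_eager_seconds(compile_s: float, speedup: float) -> float:
    """Eager-equivalent run time W after which compiling has paid for itself.

    Solves W = compile_s + W / speedup, giving W = compile_s * speedup / (speedup - 1).
    """
    if speedup <= 1.0:
        return math.inf  # no speedup: compilation never pays off
    return compile_s * speedup / (speedup - 1.0)


if __name__ == "__main__":
    lengths = [37, 128, 129, 900, 2000]
    print(sorted({bucket_length(n) for n in lengths}))  # distinct shapes to compile
    print(len(pad_to_bucket(list(range(37)))))
    print(break_even_eager_seconds(compile_s=300, speedup=1.3) / 60)  # minutes
```

For the five example lengths only four distinct shapes survive, at the cost of computing on padding. Separately, 300 s of compilation at a 1.3× speedup breaks even after about 1,300 s (roughly 22 minutes) of eager-equivalent work. That is comfortably inside the "within hours" the compile-latency row allows.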

torch._dynamo.utils.counters) before celebrating a "successful" compile; one un-padded dynamic dimension is usually the culprit.

The honest framing: JIT compilation trades flexibility you were not using (arbitrary Python between ops, fully dynamic shapes) for bandwidth and overhead wins. When you genuinely use that flexibility, the trade reverses, which is precisely the design split explored in JAX vs PyTorch vs TensorFlow.