Applied ML

JAX vs PyTorch vs TensorFlow

Three answers to one question: when does the graph exist?

01 · First principles · The one design question

Every framework needs a graph of the computation, for two non-negotiable reasons: autodiff must walk the chain of operations backward, and the compiler needs the whole chain to fuse and schedule it (see JIT compilation). The frameworks differ less in kernels (everyone ultimately lands on the same vendor libraries, such as cuDNN and cuBLAS, or on a compiler such as XLA) than in when that graph comes into existence, and what the user must promise to make it exist. That single choice cascades into debuggability, performance, and what the code feels like to write.

02 · Three answers · Before, during, instead

TF1 · define-then-run
Build the graph explicitly, in advance; execute it later inside a session. The compiler sees everything; the human sees nothing: a shape error surfaces at run time, far from the line that caused it, and print shows a symbolic tensor rather than a value. Maximal optimisation, minimal ergonomics. Research users fled; TF2 capitulated, making eager execution the default and retrofitting tf.function on top to recover graphs.
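To make the two-phase contract concrete, here is a toy define-then-run graph in plain Python. It is a hypothetical sketch in the spirit of TF1, not its API: Node, placeholder and run are invented names, and run plays the role of a session. Note that printing the graph shows a node, and a missing input only fails once the graph is executed.

```python
# Toy define-then-run graph, loosely in the spirit of TF1 (hypothetical names, not TF's API).
class Node:
    def __init__(self, op, inputs=(), name=None):
        self.op = op          # "placeholder", "add" or "mul"
        self.inputs = inputs  # upstream Node objects
        self.name = name

    def __add__(self, other):
        return Node("add", (self, other))

    def __mul__(self, other):
        return Node("mul", (self, other))

    def __repr__(self):
        # Printing shows the symbolic node, never a number.
        return f"<Node op={self.op}>"


def placeholder(name):
    return Node("placeholder", name=name)


def run(node, feed):
    """Execute a previously built graph, like a session, by evaluating inputs recursively."""
    if node.op == "placeholder":
        return feed[node.name]  # a missing feed only fails here, at run time
    args = [run(n, feed) for n in node.inputs]
    if node.op == "add":
        return args[0] + args[1]
    if node.op == "mul":
        return args[0] * args[1]
    raise ValueError(f"unknown op {node.op}")


# Phase 1: define. Nothing is computed yet.
a = placeholder("a")
b = placeholder("b")
y = a * b + a
print(y)                              # <Node op=add>

# Phase 2: run, later, with concrete values fed in.
print(run(y, {"a": 3.0, "b": 4.0}))   # 15.0
```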
PyTorch · define-by-run
The "graph" is just the trail of operations actually executed, recorded by autograd as Python runs. Code is Python: breakpoints work, prints print numbers, data-dependent control flow is just an if. The price is that no complete graph exists ahead of time — the gap torch.compile now works to close from the eager side.
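A toy sketch of define-by-run in plain Python: each operation records its inputs and local derivatives as it executes, so the "graph" is exactly the trail the program took, including whichever branch of an if actually ran. Var and f are hypothetical names, and this is a minimal reverse-mode tape, not how PyTorch's autograd is implemented.

```python
# Toy reverse-mode autodiff that records a tape while ordinary Python runs.
class Var:
    def __init__(self, value, parents=()):
        self.value = value
        self.parents = parents  # (parent Var, local derivative) pairs
        self.grad = 0.0

    def __add__(self, other):
        return Var(self.value + other.value, ((self, 1.0), (other, 1.0)))

    def __mul__(self, other):
        return Var(self.value * other.value,
                   ((self, other.value), (other, self.value)))

    def backward(self):
        # Topologically order the recorded nodes, then apply the chain rule in reverse.
        order, seen = [], set()

        def visit(v):
            if id(v) not in seen:
                seen.add(id(v))
                for p, _ in v.parents:
                    visit(p)
                order.append(v)

        visit(self)
        self.grad = 1.0
        for v in reversed(order):
            for p, local in v.parents:
                p.grad += local * v.grad


def f(x):
    # Data-dependent control flow is just an `if` on a real number.
    if x.value > 0:
        return x * x
    return x + x


x = Var(3.0)
y = f(x)
print(y.value)  # 9.0, a real number, printable mid-computation
y.backward()
print(x.grad)   # 6.0, d(x*x)/dx at x = 3
```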
JAX · trace pure functions
JAX neither builds a graph up front nor merely records one: you write a pure function of arrays, and JAX traces it (runs it on abstract values that carry shapes and dtypes but no data) to obtain a complete graph on demand, which XLA then compiles. Eager-feeling code, whole-program compilation. The promise extracted in return: the function must be pure, and that promise is load-bearing.
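A minimal sketch of what tracing means in practice, assuming a recent JAX release: jax.eval_shape runs a function on abstract ShapeDtypeStruct inputs only, jax.make_jaxpr prints the graph JAX recovered (a jaxpr), and jax.jit hands that graph to XLA. The function f is a hypothetical example.

```python
import jax
import jax.numpy as jnp

def f(x):
    # During tracing, x is an abstract tracer: it has a shape and dtype, no values.
    return jnp.sin(x) * 2.0 + x

# Trace f on abstract inputs only; no real array is needed.
spec = jax.ShapeDtypeStruct((3,), jnp.float32)
out_spec = jax.eval_shape(f, spec)
print(out_spec)   # shape (3,), dtype float32, and still no data

# The complete graph JAX obtained by tracing, printed as a jaxpr.
jaxpr = jax.make_jaxpr(f)(jnp.ones(3))
print(jaxpr)

# jit traces once per input shape/dtype, then hands the graph to XLA to compile.
f_fast = jax.jit(f)
print(f_fast(jnp.arange(3.0)))
```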
The price of purity
No in-place mutation (x.at[i].set(v) returns a new array), explicit RNG keys split and threaded by hand, parameters passed as arguments rather than living in modules, and no Python if on traced values (data-dependent branches need jax.lax.cond). Side effects do not fail loudly: a print inside jit fires only while the function is being traced (once per new input shape and dtype) and shows an abstract tracer, then stays silent on every cached call; jax.debug.print is the runtime alternative.
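A minimal sketch of these costs in JAX, assuming a recent release (.at[].set, jax.random.PRNGKey, jax.lax.cond and jax.debug.print are public API). The function relu_scalar and the trace_count counter are hypothetical, added only to make trace-time behaviour visible.

```python
import jax
import jax.numpy as jnp

# 1. No in-place mutation: .at[...].set returns a new array.
x = jnp.zeros(4)
y = x.at[1].set(5.0)      # x itself is unchanged
# x[1] = 5.0 would raise a TypeError: JAX arrays are immutable.

# 2. Explicit RNG keys, split and threaded by hand.
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
noise = jax.random.normal(subkey, (3,))
same = jax.random.normal(subkey, (3,))   # same key, same numbers: reproducible

# 3. Branch on traced values with lax.cond, not a Python `if`.
trace_count = 0

@jax.jit
def relu_scalar(v):
    global trace_count
    trace_count += 1                            # Python side effect: runs only while tracing
    print("tracing, v =", v)                    # prints a tracer, not a number
    jax.debug.print("running, v = {v}", v=v)    # runtime print, fires on every call
    # `if v > 0:` would fail here: v has no concrete value during tracing.
    return jax.lax.cond(v > 0, lambda u: u, lambda u: jnp.zeros_like(u), v)

a = relu_scalar(jnp.float32(2.0))
b = relu_scalar(jnp.float32(-3.0))   # same shape and dtype: cached, no retrace
print(trace_count)                    # 1
```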

03 · JAX's payoffComposable transforms

Purity is what JAX charges; this is what it buys. Because a traced function is a closed mathematical object, transformations of it compose like functions do:

| Transform | Takes f and returns… | Replaces |
|---|---|---|
| grad(f) | The gradient function, itself traceable | Autograd tape machinery |
| jit(f) | f compiled whole by XLA | torch.compile, without graph breaks |
| vmap(f) | f vectorised over a new batch axis | Hand-written batching, loops |
| pmap / shard_map(f) | f running SPMD across devices, collectives inside | Much of the DDP/FSDP wrapper stack |
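A minimal sketch of the first three transforms in isolation, assuming a standard JAX install; f is a hypothetical toy cubic chosen so the derivatives are easy to check by hand (3x² and 6x).

```python
import jax
import jax.numpy as jnp

def f(x):
    return x ** 3

# grad returns a function, and that function is itself traceable.
df = jax.grad(f)                 # 3x^2
d2f = jax.grad(jax.grad(f))      # 6x
print(df(2.0), d2f(2.0))         # 12.0 12.0

# vmap adds a batch axis instead of a Python loop.
xs = jnp.array([1.0, 2.0, 3.0])
looped = jnp.stack([df(x) for x in xs])
batched = jax.vmap(df)(xs)       # same values, one vectorised call

# jit compiles the whole (second-derivative) function with XLA.
fast_d2f = jax.jit(d2f)
```

Because each transform returns an ordinary function, the last line compiles a second derivative as readily as it would compile f itself.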

The composition is the point: jit(vmap(grad(f))) is per-example gradients, differentiated, batched, and compiled in one line; wrapping it in pmap or shard_map distributes it, still with no framework machinery in sight. Things that are research projects in other ecosystems (per-sample gradients for differentially private training, meta-gradients through training steps, ensembles via vmap over parameter stacks) are compositions in JAX. Sharding follows the same philosophy: annotate how arrays are laid out across the device mesh, and the XLA partitioner derives the communication, rather than the user orchestrating it imperatively.
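A sketch of the per-example-gradient composition, plus sharding by annotation, assuming a recent JAX release with the jax.sharding module. The loss, the tiny linear model and the data are hypothetical; on a single-device machine the mesh has one device, so the sharding part runs but moves no data between devices.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def loss(w, x, y):
    # Squared error of a linear model on ONE example (x: features, y: target).
    return (jnp.dot(x, w) - y) ** 2

# Differentiate, then batch over examples (w shared, x and y mapped), then compile.
per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0)))

w = jnp.array([1.0, -1.0])
xs = jnp.array([[1.0, 2.0], [3.0, 4.0], [0.5, 0.5]])
ys = jnp.array([0.0, 1.0, 2.0])
g = per_example_grads(w, xs, ys)   # shape (3, 2): one gradient row per example
print(g)                           # rows [-2, -4], [-12, -16], [-2, -2]

# Sharding by annotation: lay the batch axis out over a device mesh and let
# XLA's partitioner insert whatever communication the reduction needs.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))
n = len(devices)
batch = jax.device_put(jnp.ones((8 * n, 4)), NamedSharding(mesh, P("data", None)))
col_sums = jax.jit(lambda a: a.sum(axis=0))(batch)
```

Averaging the per-example rows recovers the ordinary mean-loss gradient; keeping them separate is what per-sample clipping in differentially private training needs.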

04 · Side by side · The comparison table

| | PyTorch | JAX | TensorFlow |
|---|---|---|---|
| Graph exists | During execution (autograd trail); ahead of time via torch.compile | At trace time, on demand, whole | TF1: before. TF2: eager, plus tf.function tracing and retracing |
| State & params | Mutable modules, in-place ops | Pure functions; state threaded explicitly (pytrees) | Mutable tf.Variable in Keras objects |
| Debugging | Native Python, best in class | Good eager; inside jit, trace-time surprises | Historically the complaint that built PyTorch |
| Compiler story | torch.compile, partial graphs, guards | XLA-native, whole program, designed in | XLA available, grafted on |
| Hardware | GPUs first; TPU support secondary | TPUs first-class; GPUs well supported | TPUs supported; the original TPU framework |
| Randomness | Global stateful RNG | Explicit keys (reproducible, verbose) | Global, with seeds |
| Ecosystem center | Research papers, Hugging Face models | Scaling shops, RL, scientific computing | Legacy production, mobile (TFLite) |
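To illustrate the "state threaded explicitly (pytrees)" row, a minimal sketch of a training step where parameters are a plain nested dict and the update is a pure function from old state to new. The model, data, learning rate and step count are hypothetical; real projects often reach for libraries such as Flax and Optax for this, but the underlying pattern is the same.

```python
import jax
import jax.numpy as jnp

LR = 0.1  # hypothetical learning rate

def loss(params, x, y):
    # Linear model; params is a pytree (a dict of arrays), not a module object.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

@jax.jit
def sgd_step(params, x, y):
    # Pure update: old state in, new state out; nothing is mutated in place.
    grads = jax.grad(loss)(params, x, y)   # same tree structure as params
    return jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)

params = {"w": jnp.zeros(2), "b": jnp.array(0.0)}
x = jnp.array([[1.0, 2.0], [2.0, 0.5], [0.0, 1.0]])
y = jnp.array([1.0, 2.0, 0.5])

start = loss(params, x, y)
for _ in range(100):
    params = sgd_step(params, x, y)   # the caller threads state through explicitly
print(float(start), float(loss(params, x, y)))
```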

05 · Honest state · Who uses what, and why

PyTorch is the default. The large majority of new research code, the Hugging Face ecosystem, and most open-weight model releases are PyTorch-first. Define-by-run won the argument: researcher iteration speed beat compiler convenience, and the compiler gap has since been narrowed from the eager side rather than the reverse.

JAX is the choice of TPU and large-scale shops. Google DeepMind, Anthropic, and a cluster of scaling-focused labs run on it, because whole-program XLA plus mesh sharding is a genuinely better substrate at thousand-chip scale, and because functional purity pays increasing dividends as systems grow (reproducibility, checkpointable state, no hidden mutation). Its research mindshare outside those shops remains a minority position.

TensorFlow is legacy production. Plenty still runs — established serving stacks, TFLite on mobile, older recommender systems — but new projects rarely start there, and Google's own center of gravity moved to JAX. Choosing it for new work in 2026 needs a specific reason (usually an existing deployment pipeline).

Choosing: follow your collaborators and your hardware. Research with the community's code: PyTorch. TPU pods, or a team that thinks functionally and needs exotic transforms: JAX. An existing TF production stack: stay until there is a reason to move. The frameworks converge in capability; the lasting differences are the contracts they ask you to sign.