
BatchNorm / LayerNorm / RMSNorm

Same normalisation, three choices of axis

01 · First principles: Why normalise inside the network

Deep nets are products of layers, so activation scale compounds with depth — bad initialisation or a few large updates and the distributions inside the network drift, saturating activations and destabilising training. A norm layer re-standardises activations every forward pass: subtract a mean, divide by a standard deviation, then restore expressiveness with learned scale and shift.

x̂ = (x − μ) / √(σ² + ε),    y = γ x̂ + β
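Here is a minimal pure-Python sketch of that recipe. It assumes a flat list of numbers and uses scalar γ and β for simplicity, whereas real layers keep one γ and β per feature. The function name `normalise` and the default `eps=1e-5` are illustrative choices and are not taken from any library.

```python
import math

def normalise(xs, gamma=1.0, beta=0.0, eps=1e-5):
    """Standardise a list of numbers, then apply a learned-style scale and shift.

    Implements x_hat = (x - mu) / sqrt(var + eps) and y = gamma * x_hat + beta,
    using the biased (population) variance. gamma and beta are scalars here
    purely to keep the sketch short.
    """
    n = len(xs)
    mu = sum(xs) / n
    var = sum((x - mu) ** 2 for x in xs) / n
    denom = math.sqrt(var + eps)
    return [gamma * (x - mu) / denom + beta for x in xs]

out = normalise([1.0, 2.0, 3.0, 4.0])
print(out)  # mean ~0, variance ~1
```

With the defaults the output has mean 0 and variance just under 1 (ε shrinks it slightly). Passing `gamma=2.0, beta=3.0` moves the mean to 3 and the standard deviation to about 2, which is how γ and β give back the scale and offset that standardisation removed.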

The original story was "internal covariate shift": each layer's input distribution keeps moving, so each layer chases a moving target, and pinning the distribution fixes that. The story is intuitive and, as stated, mostly wrong; later work showed BatchNorm still helps even when the shift is deliberately re-injected. The measured benefit is geometric: normalisation makes the loss landscape smoother (smaller, more stable gradients; better effective Lipschitz constants), which permits much higher learning rates, faster convergence, and much less sensitivity to initialisation. Wrong story, right layer.

All three norms follow the same standardise-then-rescale recipe (RMSNorm simply skips the mean). The entire difference is the axis along which the statistics are taken, and that one choice decides where each can be used.

02 · BatchNorm: Statistics across the batch

BatchNorm normalises each feature (channel) using the mean and variance over the examples in the mini-batch (and, in convnets, over the spatial positions as well). That choice has teeth:

  1. Small batches break it. The statistics are estimates from |B| samples; at batch size 4 they are dominated by noise, and the layer injects that noise into every activation. Quality degrades sharply as batches shrink.
  2. Train and test disagree. At test time there is no batch (or only one example), so inference uses running averages collected during training. Two different functions, one model — a classic source of subtle bugs and train/eval gaps.
  3. Sequences fit badly. Coupling examples in a batch is awkward with variable-length sequences and impossible in autoregressive decoding, where tokens arrive one at a time.
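The train/test split in point 2 is easiest to see in code. Below is a hypothetical pure-Python `TinyBatchNorm` sketch. The class name, the `momentum=0.1` exponential moving average and the `training` flag are assumptions made for illustration, not any library's actual API. In training mode it normalises each column (one feature, all examples) with the current batch's statistics and updates running averages; in eval mode it reuses those averages instead.

```python
import math

class TinyBatchNorm:
    """Hypothetical BatchNorm over a batch of feature vectors (rows = examples).

    gamma/beta start at 1/0 and are never trained in this sketch; running
    statistics are an exponential moving average controlled by `momentum`.
    """

    def __init__(self, num_features, momentum=0.1, eps=1e-5):
        self.gamma = [1.0] * num_features
        self.beta = [0.0] * num_features
        self.running_mean = [0.0] * num_features
        self.running_var = [1.0] * num_features
        self.momentum = momentum
        self.eps = eps
        self.training = True

    def __call__(self, batch):
        n = len(batch)
        f = len(self.gamma)
        if self.training:
            # Statistics down each column: one feature, all examples in the batch.
            mean = [sum(row[j] for row in batch) / n for j in range(f)]
            var = [sum((row[j] - mean[j]) ** 2 for row in batch) / n for j in range(f)]
            # Fold this batch into the running averages used at inference time.
            m = self.momentum
            for j in range(f):
                self.running_mean[j] = (1 - m) * self.running_mean[j] + m * mean[j]
                self.running_var[j] = (1 - m) * self.running_var[j] + m * var[j]
        else:
            # Inference: ignore the current batch and use the stored averages.
            mean, var = self.running_mean, self.running_var
        return [
            [self.gamma[j] * (row[j] - mean[j]) / math.sqrt(var[j] + self.eps) + self.beta[j]
             for j in range(f)]
            for row in batch
        ]

bn = TinyBatchNorm(2)
batch = [[1.0, 10.0], [3.0, 30.0]]
train_out = bn(batch)   # approximately [[-1, -1], [1, 1]]
bn.training = False
eval_out = bn(batch)    # different: running stats have only seen one batch
print(train_out, eval_out)
```

The same inputs map to different outputs in the two modes. A training batch of a single example is the extreme case of point 1: every feature's variance is zero, so the sketch normalises the whole example to exactly zero.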

Where the batch is large and examples are exchangeable, as in image classification with convolutional networks (CNNs), BatchNorm remains excellent, and its batch coupling even acts as a mild regulariser.

03 · LayerNorm / RMSNorm: Statistics within one example

LayerNorm flips the axis: for each individual example (each token, in a transformer), normalise across its feature dimension. Every property that hurt BatchNorm inverts: no dependence on batch size, identical behaviour in training and inference, and perfectly defined for a single token in a sequence of any length. That batch-independence — not any subtlety — is why transformers use LayerNorm.
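For contrast, here is a pure-Python LayerNorm sketch. The names `layer_norm` and `layer_norm_batch` are hypothetical, and it keeps one γ and β per feature, as the layer does. Each row is normalised using only its own features, so a token's output cannot depend on which other tokens share its batch, and there is no separate inference path.

```python
import math

def layer_norm(token, gamma, beta, eps=1e-5):
    """Normalise one token's feature vector across its own features."""
    d = len(token)
    mu = sum(token) / d
    var = sum((x - mu) ** 2 for x in token) / d
    inv_std = 1.0 / math.sqrt(var + eps)
    return [g * (x - mu) * inv_std + b for x, g, b in zip(token, gamma, beta)]

def layer_norm_batch(batch, gamma, beta, eps=1e-5):
    """Apply LayerNorm row by row; rows never see each other's statistics."""
    return [layer_norm(row, gamma, beta, eps) for row in batch]

gamma, beta = [1.0] * 4, [0.0] * 4
token = [1.0, 2.0, 3.0, 4.0]
alone = layer_norm_batch([token], gamma, beta)[0]
together = layer_norm_batch([token, [100.0, -5.0, 0.0, 7.0]], gamma, beta)[0]
print(alone == together)  # True
```

A batch of one is as well defined as a batch of a thousand, which is exactly the property autoregressive decoding needs.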

RMSNorm is LayerNorm without the re-centring: divide by the root-mean-square of the features, and drop both the mean subtraction (μ) and the shift (β).

y = γ · x / RMS(x),    RMS(x) = √( (1/d) Σᵢ xᵢ² )
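The formula translates almost directly into code. This pure-Python `rms_norm` is a hypothetical sketch. It adds a small `eps` inside the square root, mirroring the ε in the LayerNorm formula and the way implementations commonly guard against division by zero; the formula above omits it.

```python
import math

def rms_norm(x, gamma, eps=1e-6):
    """RMSNorm for one token: scale by 1/RMS(x); no mean subtraction, no beta."""
    d = len(x)
    rms = math.sqrt(sum(v * v for v in x) / d + eps)
    return [g * v / rms for v, g in zip(x, gamma)]

y = rms_norm([3.0, 4.0], [1.0, 1.0])
print(y)  # roughly [0.849, 1.131]
```

With γ = 1 the output always has a mean square of 1 (up to ε), but, unlike LayerNorm, its mean is not forced to zero: both outputs above are positive.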

One fewer reduction, fewer parameters, measurably faster — and in practice it trains as well as LayerNorm, suggesting the re-scaling was doing nearly all the work and the re-centring was ballast. Llama-family and most recent LLMs use RMSNorm.

[Figure: two grids with examples (batch) down the rows and features across the columns. BatchNorm panel: one feature, all examples. LayerNorm panel: one example, all features.]

The whole difference is the highlighted axis. BatchNorm couples examples vertically; LayerNorm (and RMSNorm) stays inside one row.
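The two axes can be made concrete on a tiny batch stored as rows (examples) by columns (features). The helper names below are hypothetical, and they compute only the means, which is enough to show which entries each norm pools together.

```python
def column_means(batch):
    """BatchNorm's axis: one mean per feature, taken over all examples."""
    n = len(batch)
    return [sum(row[j] for row in batch) / n for j in range(len(batch[0]))]

def row_means(batch):
    """LayerNorm's axis: one mean per example, taken over its own features."""
    return [sum(row) / len(row) for row in batch]

batch = [
    [1.0, 2.0, 3.0],
    [4.0, 5.0, 6.0],
]
print(column_means(batch))  # [2.5, 3.5, 4.5]
print(row_means(batch))     # [2.0, 5.0]

# Append one more example to the batch.
bigger = batch + [[10.0, 20.0, 30.0]]
print(column_means(bigger))  # [5.0, 9.0, 13.0]  every column statistic moved
print(row_means(bigger))     # [2.0, 5.0, 20.0]  existing rows unchanged
```

Adding a single example shifts every statistic BatchNorm would use for the examples already in the batch, while the per-row statistics LayerNorm uses for those examples stay put.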

04 · Side by side: Where each lives

Norm      | Statistics over              | Batch-dependent? | Train = test?      | Lives in
BatchNorm | batch, per channel           | yes              | no (running stats) | convnets, large-batch vision
LayerNorm | features, per token          | no               | yes                | transformers, RNNs
RMSNorm   | features, per token, no mean | no               | yes                | modern LLMs (Llama et al.)

Placement footnote: transformers also moved the norm from after the residual add (post-norm) to before the block (pre-norm), which keeps the residual path clean and makes very deep stacks trainable without warmup heroics. Axis and placement are the two decisions that matter.
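A schematic sketch of the two placements follows, under explicit assumptions: `zero_sublayer` is a hypothetical stand-in for attention or an MLP, `rms_norm` here is a parameter-free RMSNorm (γ fixed at 1), and vectors are plain Python lists. The only difference between the two blocks is where the norm sits relative to the residual add.

```python
import math

def rms_norm(x, eps=1e-6):
    """Parameter-free RMSNorm (gamma fixed at 1), used as the norm in both blocks."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]

def post_norm_block(x, sublayer, norm=rms_norm):
    # Post-norm, the original Transformer ordering: Norm(x + Sublayer(x)).
    return norm([a + b for a, b in zip(x, sublayer(x))])

def pre_norm_block(x, sublayer, norm=rms_norm):
    # Pre-norm: x + Sublayer(Norm(x)); the residual x is added back untouched.
    return [a + b for a, b in zip(x, sublayer(norm(x)))]

def zero_sublayer(v):
    """Hypothetical sublayer that contributes nothing, to isolate the residual path."""
    return [0.0] * len(v)

x = [3.0, 4.0]
print(pre_norm_block(x, zero_sublayer))   # [3.0, 4.0]
print(post_norm_block(x, zero_sublayer))  # roughly [0.849, 1.131]
```

With a sublayer that outputs zeros, the pre-norm block returns its input unchanged, so a deep stack keeps a clean identity path through the residual stream. The post-norm block rescales the residual sum on every pass, so the identity path runs through a norm in every layer.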
Mental Model