Sampling a category without killing the gradient
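Suppose a network must make a discrete choice mid-forward-pass (picking a codebook entry, a token, an architecture op) with probabilities $\pi = (\pi_1, \ldots, \pi_K)$ that it computed itself, and then learn from what happens downstream. Two walls go up at once. First, the draw is random, with a distribution that itself depends on π, so there is no deterministic path from π to the outcome to differentiate along. Second, the outcome is discrete, so even with the randomness fixed it is a piecewise-constant function of π whose gradient is zero almost everywhere.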
The blunt alternatives are both painful: REINFORCE-style score-function estimators are unbiased but notoriously high-variance, and simply passing the probabilities π downstream in place of a sample gives a smooth path but forgoes sampling altogether. What we want is something that behaves like a sample yet is a differentiable function of π.
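First, a curious identity. Draw $K$ independent Gumbel noises $g_i = -\log(-\log u_i)$, with $u_i \sim \mathrm{Uniform}(0,1)$. Then the index of the largest perturbed log-probability is an exact categorical sample:

$$k^\ast = \arg\max_{i \in \{1, \ldots, K\}} \left(\log \pi_i + g_i\right) \sim \mathrm{Categorical}(\pi), \qquad \text{i.e.} \quad \Pr(k^\ast = k) = \pi_k .$$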
The one-line intuition: the Gumbel distribution is precisely the noise whose maxima behave like log-probability races, so perturbing each log-score with Gumbel noise and taking the winner reproduces categorical sampling exactly (the same family governs extreme values generally, which is where the distribution comes from). The payoff is structural — this is a reparameterisation, the same move the VAE makes for Gaussians: all the randomness now enters through an independent noise source g, and the path from the parameters π to the outcome is a deterministic function. Only one obstacle remains on that path: the argmax.
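A minimal pure-Python sketch of the Gumbel-max identity described above, using only the standard library; the names `sample_gumbel` and `gumbel_max_sample` are illustrative, not from any library. It draws Gumbel noise with the inverse-CDF formula, takes the argmax of the perturbed log-probabilities, and compares empirical frequencies with π. It assumes every πᵢ > 0, since log 0 is undefined.

```python
import math
import random
from collections import Counter


def sample_gumbel(rng: random.Random) -> float:
    """One standard Gumbel draw via g = -log(-log u), u ~ Uniform(0, 1)."""
    u = rng.random()
    # random() returns values in [0, 1); resample the (rare) u == 0, where log(0) fails.
    while u == 0.0:
        u = rng.random()
    return -math.log(-math.log(u))


def gumbel_max_sample(probs: list[float], rng: random.Random) -> int:
    """Return argmax_i (log pi_i + g_i), which is distributed as Categorical(pi).

    Assumes every entry of probs is strictly positive.
    """
    perturbed = [math.log(p) + sample_gumbel(rng) for p in probs]
    return max(range(len(probs)), key=perturbed.__getitem__)


if __name__ == "__main__":
    rng = random.Random(0)
    probs = [0.5, 0.3, 0.2]
    n = 100_000
    counts = Counter(gumbel_max_sample(probs, rng) for _ in range(n))
    for k, p in enumerate(probs):
        print(f"category {k}: target {p:.3f}, empirical {counts[k] / n:.3f}")
```

With this many draws the empirical frequencies should sit within about a percentage point of π. Note that the noise generator never looks at π: all of the dependence on the parameters flows through the deterministic `log(p) + g` and the argmax, which is the reparameterisation at work.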
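Replace the hard argmax with a softmax at temperature τ > 0, applied to the same perturbed logits:

$$y_i = \frac{\exp\!\big((\log \pi_i + g_i)/\tau\big)}{\sum_{j=1}^{K} \exp\!\big((\log \pi_j + g_j)/\tau\big)}, \qquad i = 1, \ldots, K.$$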
The output y is no longer a one-hot vector but a point on the simplex — a "soft sample" that is differentiable in π everywhere. As τ → 0 the softmax sharpens into the argmax and y recovers exact categorical samples; as τ grows, y melts toward uniform. We have traded exactness for a gradient: the relaxation is biased (downstream computations see a mixture of categories, which no real sample is) but smooth.
The same Gumbel-perturbed logits, squashed at three temperatures: uniform-ish, soft, and effectively discrete.
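The softmax relaxation can be sketched in the same style. The snippet below (again hypothetical helper names, standard library only) fixes one Gumbel draw and evaluates the soft sample at three temperatures, mirroring the figure. Subtracting the largest scaled logit before exponentiating is a standard stability trick that keeps small τ from overflowing `exp()`; it does not change the result.

```python
import math
import random


def sample_gumbel(rng: random.Random) -> float:
    """Standard Gumbel draw, g = -log(-log u), resampling the (rare) u == 0."""
    u = rng.random()
    while u == 0.0:
        u = rng.random()
    return -math.log(-math.log(u))


def softmax(z: list[float], tau: float) -> list[float]:
    """Temperature softmax; subtracting the max keeps exp() from overflowing at small tau."""
    scaled = [v / tau for v in z]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def gumbel_softmax(probs: list[float], gumbels: list[float], tau: float) -> list[float]:
    """Soft sample y_i = softmax((log pi_i + g_i) / tau) for a given noise vector g."""
    return softmax([math.log(p) + g for p, g in zip(probs, gumbels)], tau)


if __name__ == "__main__":
    rng = random.Random(0)
    probs = [0.5, 0.3, 0.2]
    gumbels = [sample_gumbel(rng) for _ in probs]  # one shared draw for every tau
    for tau in (5.0, 1.0, 0.1):
        y = gumbel_softmax(probs, gumbels, tau)
        print(f"tau={tau:>4}: " + ", ".join(f"{v:.3f}" for v in y))
```

Because the softmax is monotone, the largest entry of y always sits at the same index as the hard Gumbel-max winner for any τ > 0; lowering τ only concentrates more mass on that entry, and raising it spreads mass toward uniform.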
| Temperature | Samples look like | Gradients are | Bias |
|---|---|---|---|
| High τ | Smooth blends, near uniform | Low variance, well behaved | Heavy: nothing downstream ever sees a real category |
| Low τ | Near one-hot, nearly faithful | High variance, spiky (the function is almost a step) | Small; vanishes as τ → 0 |
The standard recipe is to anneal: start warm so learning gets clean gradients while the categories are still being sorted out, and cool toward discreteness as training stabilises. (In practice many systems just fix τ ≈ 0.5–1 and accept the bias.)
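One way to implement the anneal is an exponential decay with a floor. The schedule shape and the defaults `tau0=1.0`, `rate=1e-4`, `tau_min=0.5` below are assumptions for illustration, not values taken from this text; linear or stepwise schedules are equally plausible, and the floor simply stops cooling at the low end of the range mentioned above.

```python
import math


def annealed_tau(step: int, tau0: float = 1.0, rate: float = 1e-4,
                 tau_min: float = 0.5) -> float:
    """Temperature at a training step: exponential decay from tau0, floored at tau_min.

    The schedule shape and all three defaults are illustrative assumptions.
    """
    return max(tau_min, tau0 * math.exp(-rate * step))


if __name__ == "__main__":
    for step in (0, 1_000, 5_000, 10_000, 50_000):
        print(f"step {step:>6}: tau = {annealed_tau(step):.3f}")
```

With these defaults the temperature reaches the 0.5 floor after about 6,900 steps (ln 2 / 10⁻⁴ ≈ 6,931) and stays there.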
The straight-through variant takes both ends of the trade at once: in the forward pass, harden y to a true one-hot (downstream sees a real category); in the backward pass, pretend the softmax was used and pass its gradient through. The estimator is biased twice over, once by the relaxation and again by the mismatch between the hard forward pass and the soft backward pass, but that mismatch is small at low τ, and the method is often the most stable option when the downstream computation genuinely requires discrete inputs (a codebook lookup, say).
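To make the forward/backward split concrete, here is a framework-free sketch with the backward pass written out by hand; the function names are hypothetical. The forward pass returns the hard one-hot; the backward pass hands the upstream gradient to the softmax as if the soft sample had been used, via ∂yᵢ/∂zⱼ = (1/τ) yᵢ(δᵢⱼ − yⱼ) with zⱼ the perturbed logit. It takes unnormalised logits θ: since log πᵢ = θᵢ − log Σⱼ exp θⱼ and the softmax ignores a shared shift, y depends on θ only through θ + g. In an autodiff framework the same effect is usually written as `y_hard + y_soft - stop_gradient(y_soft)` (`.detach()` in PyTorch).

```python
import math


def softmax(z: list[float], tau: float) -> list[float]:
    """Temperature softmax with max-subtraction for numerical stability."""
    scaled = [v / tau for v in z]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def st_forward(logits: list[float], gumbels: list[float], tau: float):
    """Forward pass: hard one-hot for downstream use, plus the soft sample kept for backward.

    Raw logits theta are fine here: the log-normaliser of pi is a shared shift,
    which the softmax ignores.
    """
    y_soft = softmax([t + g for t, g in zip(logits, gumbels)], tau)
    k = max(range(len(y_soft)), key=y_soft.__getitem__)
    y_hard = [1.0 if i == k else 0.0 for i in range(len(y_soft))]
    return y_hard, y_soft


def st_backward(grad_y: list[float], y_soft: list[float], tau: float) -> list[float]:
    """Backward pass: treat dL/dy_hard as dL/dy_soft and chain through the softmax.

    dL/dtheta_j = (1 / tau) * y_j * (grad_j - sum_i grad_i * y_i)
    """
    dot = sum(gi * yi for gi, yi in zip(grad_y, y_soft))
    return [yj * (gj - dot) / tau for gj, yj in zip(grad_y, y_soft)]


if __name__ == "__main__":
    theta = [0.2, -0.5, 1.0]    # unnormalised logits (illustrative values)
    gumbels = [0.3, -0.1, 0.7]  # a fixed noise draw, hard-coded for reproducibility
    w = [1.0, -2.0, 0.5]        # linear downstream loss L = sum_i w_i * y_i
    tau = 0.5

    y_hard, y_soft = st_forward(theta, gumbels, tau)
    loss = sum(wi * yi for wi, yi in zip(w, y_hard))  # forward sees a real category
    grad_theta = st_backward(w, y_soft, tau)          # dL/dy = w for a linear loss
    print("one-hot:", y_hard, "loss:", loss)
    print("straight-through gradient:", [round(v, 4) for v in grad_theta])
```

With a linear loss standing in for a codebook lookup, the forward value is the weight of the chosen category, while the gradient is exactly that of the soft loss Σᵢ wᵢ yᵢ; a central finite-difference check on the soft loss should agree with it.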