
Gumbel-Softmax

Sampling a category without killing the gradient

01 · The problem · Discrete sampling is a wall in the graph

Suppose a network must make a discrete choice mid-forward-pass (pick a codebook entry, a token, an architecture op) with probabilities π = (π₁, …, π_K) that it computed itself, and then learn from what happens downstream. Two walls go up at once:

  1. Sampling is not a function of π in any differentiable sense: nudge π₂ slightly and the drawn sample either does not change at all or jumps to a different category entirely. The derivative is zero almost everywhere and undefined at the jumps.
  2. argmax has the same disease: a piecewise-constant function whose gradient is zero wherever it exists. Backpropagation arrives at the choice and finds nothing to flow through.
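A tiny illustration of the first wall. It uses inverse-CDF sampling, one standard way to draw a category, chosen here only for clarity. The uniform draw u is held fixed. A small nudge to π usually leaves the sampled index unchanged, so the finite difference is zero. Only a nudge near a category boundary makes it jump. The helper names and probabilities are hypothetical.

```python
def inverse_cdf_sample(probs, u):
    """Return the first category whose cumulative probability exceeds u (0 <= u < 1)."""
    cumulative = 0.0
    for k, p in enumerate(probs):
        cumulative += p
        if u < cumulative:
            return k
    return len(probs) - 1  # guard against floating-point round-off in the last bin


def nudge(probs, k, eps):
    """Add eps to probs[k], then renormalise so the vector stays on the simplex."""
    bumped = [p + (eps if i == k else 0.0) for i, p in enumerate(probs)]
    total = sum(bumped)
    return [p / total for p in bumped]


probs = [0.2, 0.5, 0.3]

# Away from a boundary: the sample does not move, so the finite difference is 0.
u = 0.4
eps = 1e-3
before = inverse_cdf_sample(probs, u)
after = inverse_cdf_sample(nudge(probs, 1, eps), u)
print("finite difference:", (after - before) / eps)  # prints: finite difference: 0.0

# Near a boundary: a small nudge makes the sample jump to another category.
u_edge = 0.2005
print(inverse_cdf_sample(probs, u_edge), "->", inverse_cdf_sample(nudge(probs, 0, 0.01), u_edge))  # prints: 1 -> 0
```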

The blunt alternatives are both painful: REINFORCE-style score-function estimators are unbiased but notoriously high-variance, and simply replacing the sample with the soft probabilities forgoes sampling altogether. We want a sample that behaves like a differentiable function of π.
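For comparison, here is a minimal sketch of the REINFORCE-style score-function estimator mentioned above. It assumes a softmax parameterisation of π and a hypothetical per-category reward. Averaged over many draws, it matches the exact gradient, which is what "unbiased" means. On this toy problem, though, a single-sample estimate has a standard deviation larger than the gradient it is estimating.

```python
import math
import random
import statistics


def softmax(logits):
    """Numerically stable softmax."""
    m = max(logits)
    e = [math.exp(v - m) for v in logits]
    s = sum(e)
    return [v / s for v in e]


def score_function_grad(logits, reward, rng):
    """One-sample REINFORCE estimate of d E_z[reward(z)] / d logits, z ~ Categorical(softmax(logits))."""
    p = softmax(logits)
    z = rng.choices(range(len(p)), weights=p)[0]
    # For a softmax parameterisation, d log p_z / d logit_j = 1[j == z] - p_j.
    return [reward(z) * ((1.0 if j == z else 0.0) - p[j]) for j in range(len(p))]


probs = [0.2, 0.5, 0.3]
logits = [math.log(p) for p in probs]
rewards = [1.0, 2.0, 3.0]  # hypothetical per-category reward


def reward(z):
    return rewards[z]


# Exact gradient for comparison: d/dlogit_j sum_i p_i r_i = p_j (r_j - sum_i p_i r_i).
mean_reward = sum(p * r for p, r in zip(probs, rewards))
exact = [p * (r - mean_reward) for p, r in zip(probs, rewards)]

rng = random.Random(0)
samples = [score_function_grad(logits, reward, rng) for _ in range(50_000)]
estimate = [statistics.fmean(s[j] for s in samples) for j in range(len(probs))]
spread = [statistics.pstdev([s[j] for s in samples]) for j in range(len(probs))]
print("exact:         ", [round(g, 3) for g in exact])
print("estimate:      ", [round(g, 3) for g in estimate])
print("per-sample std:", [round(s, 3) for s in spread])
```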

02 · The trick · Gumbel-max: move the randomness out of the way

First, a curious identity. Draw K independent Gumbel noise samples gᵢ = −log(−log uᵢ), with uᵢ ∼ Uniform(0, 1). Then:

$$\operatorname*{arg\,max}_i \,\bigl(\log \pi_i + g_i\bigr) \;\sim\; \mathrm{Categorical}(\pi)$$

This holds exactly, not approximately.
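The identity can be checked directly by computing the probability that category k wins the race. The derivation below uses only the standard Gumbel CDF exp(−e^{−x}) and the fact that ties have probability zero:

```latex
% G_i = \log \pi_i + g_i, with g_i \sim \mathrm{Gumbel}(0,1), whose CDF is \exp(-e^{-x})
F_i(x) = \Pr(G_i \le x) = \exp\bigl(-\pi_i e^{-x}\bigr), \qquad
f_i(x) = F_i'(x) = \pi_i e^{-x} \exp\bigl(-\pi_i e^{-x}\bigr)

\begin{aligned}
\Pr\Bigl(\arg\max_i G_i = k\Bigr)
  &= \int_{-\infty}^{\infty} f_k(x) \prod_{j \neq k} F_j(x)\,dx \\
  &= \int_{-\infty}^{\infty} \pi_k e^{-x} \exp\Bigl(-e^{-x} \sum_{j=1}^{K} \pi_j\Bigr)\,dx \\
  &= \pi_k \int_{0}^{\infty} e^{-t}\,dt
     \qquad \text{(substitute } t = e^{-x}\text{ and use } \textstyle\sum_j \pi_j = 1\text{)} \\
  &= \pi_k
\end{aligned}
```

Without the normalisation step, the same integral gives π_k / Σⱼ πⱼ. So the trick also works when the scores are unnormalised log-weights.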

The intuition: log πᵢ + gᵢ is a Gumbel variable shifted by log πᵢ. In a race between such shifted Gumbels, category i wins with probability proportional to exp(log πᵢ) = πᵢ, which is exactly categorical sampling. (The distribution itself comes from extreme-value theory, where it is the limiting law for the maximum of many draws from light-tailed distributions such as the Gaussian or the exponential.) The payoff is structural: this is a reparameterisation, the same move the VAE makes for Gaussians. All the randomness now enters through an independent noise source g, and the path from the parameters π to the outcome is a deterministic function. Only one obstacle remains on that path: the argmax.
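Here is a minimal sketch of Gumbel-max sampling in plain Python, written so the reparameterisation is visible. The hypothetical `gumbel_max` is a deterministic function of π and an explicitly passed noise vector. All randomness comes from the equally hypothetical `sample_gumbel`. Over many draws, the empirical frequencies should land close to π.

```python
import math
import random


def sample_gumbel(rng):
    """Standard Gumbel(0, 1) noise via g = -log(-log u), u ~ Uniform(0, 1)."""
    u = rng.random()
    while u == 0.0:  # random() may return 0.0, and log(0) is undefined
        u = rng.random()
    return -math.log(-math.log(u))


def gumbel_max(probs, noise):
    """Deterministic given the noise: argmax_i (log pi_i + g_i). Requires every pi_i > 0."""
    scores = [math.log(p) + g for p, g in zip(probs, noise)]
    return max(range(len(scores)), key=scores.__getitem__)


probs = [0.2, 0.5, 0.3]
rng = random.Random(0)
n = 100_000
counts = [0] * len(probs)
for _ in range(n):
    noise = [sample_gumbel(rng) for _ in probs]  # all randomness lives here
    counts[gumbel_max(probs, noise)] += 1

print([c / n for c in counts])  # empirical frequencies, close to probs
```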

03 · The relaxation · Soften the argmax with temperature

Replace the hard argmax with a softmax at temperature τ, applied to the same perturbed logits:

$$y_i = \frac{\exp\bigl((\log \pi_i + g_i)/\tau\bigr)}{\sum_{j=1}^{K} \exp\bigl((\log \pi_j + g_j)/\tau\bigr)}$$

The output y is no longer a one-hot vector but a point on the simplex: a "soft sample" that, for fixed noise g and any τ > 0, is differentiable in π everywhere. As τ → 0 the softmax sharpens into the argmax and y recovers exact categorical samples. As τ grows, y melts toward the uniform vector (1/K, …, 1/K). We have traded exactness for a gradient: the relaxation is biased (downstream computations see a mixture of categories, which no real sample is) but smooth.
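Below is a sketch of the relaxed sample and its gradient, differentiating with respect to the logits log πᵢ. The chain rule to π itself only rescales column j by 1/πⱼ. For fixed noise, ∂yᵢ/∂(log πⱼ) = yᵢ(δᵢⱼ − yⱼ)/τ, and the code checks that formula against central finite differences. The function names and noise values are illustrative.

```python
import math


def gumbel_softmax(logits, noise, tau):
    """Relaxed sample y = softmax((logits + g) / tau) for fixed Gumbel noise g and tau > 0."""
    z = [(x + g) / tau for x, g in zip(logits, noise)]
    m = max(z)  # subtract the max before exponentiating, for numerical stability
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]


def analytic_jacobian(logits, noise, tau):
    """J[i][j] = dy_i / dlogit_j = y_i * (delta_ij - y_j) / tau."""
    y = gumbel_softmax(logits, noise, tau)
    k = len(y)
    return [[y[i] * ((1.0 if i == j else 0.0) - y[j]) / tau for j in range(k)] for i in range(k)]


def numeric_jacobian(logits, noise, tau, h=1e-6):
    """Central finite differences, used only to check the analytic formula."""
    k = len(logits)
    J = [[0.0] * k for _ in range(k)]
    for j in range(k):
        up = list(logits)
        down = list(logits)
        up[j] += h
        down[j] -= h
        y_up = gumbel_softmax(up, noise, tau)
        y_down = gumbel_softmax(down, noise, tau)
        for i in range(k):
            J[i][j] = (y_up[i] - y_down[i]) / (2 * h)
    return J


logits = [math.log(p) for p in (0.2, 0.5, 0.3)]  # log pi
noise = [0.3, -0.5, 1.1]  # one fixed Gumbel draw (hypothetical values)
y = gumbel_softmax(logits, noise, tau=1.0)
J = analytic_jacobian(logits, noise, tau=1.0)
J_num = numeric_jacobian(logits, noise, tau=1.0)
max_err = max(abs(a - b) for ra, rb in zip(J, J_num) for a, b in zip(ra, rb))
print("y =", [round(v, 3) for v in y])
print("max |analytic - numeric| =", max_err)
```

Note the 1/τ factor. At small τ, these entries become large whenever two perturbed logits are nearly tied. That is one way to see the spiky gradients described in section 04.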

[Figure: One draw, three temperatures, same perturbed logits. Panels: τ = 5 (near uniform), τ = 1 (soft sample), τ = 0.1 (near one-hot).]

The same Gumbel-perturbed logits, squashed at three temperatures: uniform-ish, soft, and effectively discrete.
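The figure's experiment can be sketched in a few lines, assuming one fixed (hypothetical) Gumbel draw. The perturbed logits log πᵢ + gᵢ are computed once, and only τ changes. Every temperature keeps the same winner; what changes is how sharply it wins.

```python
import math


def tempered_softmax(scores, tau):
    """softmax(scores / tau), computed stably."""
    z = [s / tau for s in scores]
    m = max(z)
    e = [math.exp(v - m) for v in z]
    total = sum(e)
    return [v / total for v in e]


probs = [0.2, 0.5, 0.3]
noise = [0.3, -0.5, 1.1]  # one fixed Gumbel draw (hypothetical values)
perturbed = [math.log(p) + g for p, g in zip(probs, noise)]  # log pi_i + g_i, shared by all three

sweep = {tau: tempered_softmax(perturbed, tau) for tau in (5.0, 1.0, 0.1)}
for tau, y in sweep.items():
    print(f"tau = {tau:>3}: y = {[round(v, 3) for v in y]}")
```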

04 · The dial · τ trades bias against gradient variance

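| Temperature | Samples look like             | Gradients are                                        | Bias                                            |
|-------------|-------------------------------|------------------------------------------------------|-------------------------------------------------|
| High τ      | Smooth blends, near uniform   | Low variance, well behaved                           | Heavy: nothing downstream sees a real category  |
| Low τ       | Near one-hot, nearly faithful | High variance, spiky (the function is almost a step) | Vanishes as τ → 0                               |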
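A rough Monte Carlo check of the table, under assumptions chosen only for illustration. The toy downstream quantity is the soft weight on the second category, whose expectation under real sampling is π = 0.5. Only the diagonal entry of its gradient is tracked. At high τ, the mean is pulled toward 1/K (bias), while the gradient barely varies. At low τ, the mean sits close to 0.5, but the gradient is mostly near zero with occasional large spikes near ties.

```python
import math
import random
import statistics


def sample_gumbel(rng):
    """Standard Gumbel(0, 1) noise via g = -log(-log u), u ~ Uniform(0, 1)."""
    u = rng.random()
    while u == 0.0:  # random() may return 0.0, and log(0) is undefined
        u = rng.random()
    return -math.log(-math.log(u))


def gumbel_softmax(logits, noise, tau):
    """Relaxed sample y = softmax((logits + g) / tau)."""
    z = [(x + g) / tau for x, g in zip(logits, noise)]
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]


def estimate(probs, tau, target, n, seed=0):
    """Monte Carlo mean of y[target] and variance of its diagonal gradient dy[target]/dlogit[target]."""
    rng = random.Random(seed)
    logits = [math.log(p) for p in probs]
    values, grads = [], []
    for _ in range(n):
        noise = [sample_gumbel(rng) for _ in probs]
        y = gumbel_softmax(logits, noise, tau)
        values.append(y[target])
        grads.append(y[target] * (1.0 - y[target]) / tau)
    return statistics.fmean(values), statistics.pvariance(grads)


probs = [0.2, 0.5, 0.3]
target = 1  # 0-based index of the category with pi = 0.5
results = {tau: estimate(probs, tau, target, n=20_000) for tau in (5.0, 0.1)}
for tau, (mean_y, grad_var) in results.items():
    print(f"tau = {tau}: E[y] = {mean_y:.3f} (exact sampling: {probs[target]}), grad variance = {grad_var:.5f}")
```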

The standard recipe is to anneal: start warm so learning gets clean gradients while the categories are still being sorted out, and cool toward discreteness as training stabilises. (In practice many systems just fix τ ≈ 0.5–1 and accept the bias.)
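One possible annealing schedule, sketched under assumptions: exponential decay from τ₀ toward a floor τ_min. The form and the default values (τ₀ = 1, τ_min = 0.5, rate 10⁻⁴ per step) are illustrative choices, not a prescription.

```python
import math


def annealed_tau(step, tau0=1.0, tau_min=0.5, rate=1e-4):
    """Hypothetical schedule: decay tau exponentially from tau0, never going below tau_min."""
    return max(tau_min, tau0 * math.exp(-rate * step))


for step in (0, 2_000, 5_000, 10_000):
    print(step, round(annealed_tau(step), 3))
```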

The straight-through variant takes both ends of the trade at once: in the forward pass, harden y to a true one-hot (downstream sees a real category); in the backward pass, pretend the softmax was used and pass its gradient through. The estimator is biased twice over, but the forward/backward mismatch is small at low τ, and it is often the most stable option when the downstream computation genuinely requires discrete inputs (a codebook lookup, say).
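A framework-free sketch of the straight-through variant, assuming a hypothetical linear downstream loss. The forward function returns the hard one-hot, which is exactly the Gumbel-max sample because softmax preserves the argmax. The backward function applies the soft Jacobian yᵢ(δᵢⱼ − yⱼ)/τ to the incoming gradient, as if the soft sample had been used.

```python
import math


def softmax(z):
    """Numerically stable softmax."""
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]


def straight_through_forward(logits, noise, tau):
    """Forward pass: return the hard one-hot that downstream code sees, plus the soft y saved for backward."""
    y_soft = softmax([(x + g) / tau for x, g in zip(logits, noise)])
    winner = max(range(len(y_soft)), key=y_soft.__getitem__)  # same index as Gumbel-max
    y_hard = [1.0 if i == winner else 0.0 for i in range(len(y_soft))]
    return y_hard, y_soft


def straight_through_backward(y_soft, tau, grad_y):
    """Backward pass: treat the forward output as if it were y_soft.

    dL/dlogit_j = sum_i grad_y[i] * y_i * (delta_ij - y_j) / tau
                = y_j * (grad_y[j] - sum_i grad_y[i] * y_i) / tau
    """
    dot = sum(gy * ys for gy, ys in zip(grad_y, y_soft))
    return [ys * (gy - dot) / tau for ys, gy in zip(y_soft, grad_y)]


logits = [math.log(p) for p in (0.2, 0.5, 0.3)]
noise = [0.3, -0.5, 1.1]  # one fixed Gumbel draw (hypothetical values)
tau = 0.5
y_hard, y_soft = straight_through_forward(logits, noise, tau)

# Hypothetical downstream loss L = sum_i w_i * y_i, evaluated on the hard sample.
weights = [1.0, 2.0, 3.0]
loss = sum(w * y for w, y in zip(weights, y_hard))
grad_logits = straight_through_backward(y_soft, tau, grad_y=weights)
print("forward sees:", y_hard, "loss =", loss)
print("gradient w.r.t. logits:", [round(g, 4) for g in grad_logits])
```

In autograd frameworks, the same effect is commonly written as y_hard - stop_gradient(y_soft) + y_soft. Its value is y_hard, and its gradient is that of y_soft. PyTorch exposes this as torch.nn.functional.gumbel_softmax(logits, tau=..., hard=True).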

05 · Where it earns its keep · Discrete choices inside differentiable systems

One-sentence summary for interviews: Gumbel-max makes categorical sampling a deterministic function of parameters plus external noise; Gumbel-softmax replaces that function's argmax with a temperature-controlled softmax, buying differentiability at the price of a bias you control with τ.
Mental Model