The Gumbel-max trick is a useful technique for sampling from a discrete distribution, especially when built-in statistical functions are unavailable. Although many statistical software packages can handle a wide range of distributions, there are situations where you need to sample from distributions that may not be explicitly supported.
In this blog, we will discuss the Gumbel-max trick, its mathematical proof, and a simulation example. Before diving into the trick, let’s start with something easier: a brief review of the inverse cumulative distribution function (CDF) method for sampling, which applies to both continuous and discrete distributions. In contrast, the Gumbel-max trick applies specifically to discrete distributions.
Inverse CDF method
Definition: The CDF of a random variable \(Z\) is given by \(F_Z(z) = P(Z \leq z)~\forall~z\), where \(P(\cdot)\) denotes probability. The CDF is a non-decreasing function defined on the real line \(\mathbb{R}\).
Definition: For a non-decreasing function \(F\) on \(\mathbb{R}\), the generalized inverse is defined as \(F^{-1}(u) = \min\{z: F(z) \geq u\}\), \(u \in (0, 1)\). For a continuous, strictly increasing \(F\), the inequality “\(\geq\)” becomes an equality “\(=\)” and the generalized inverse coincides with the ordinary inverse.
Suppose we want to sample \(Z \sim P(z)\) and we have access to the inverse CDF of \(Z\). Then:
Lemma: If \(U \sim \text{Uniform}(0, 1)\), then \(Z = F_Z^{-1}(U)\) is a sample from \(P(z)\).
Proof: By the definition of the generalized inverse and the monotonicity of \(F_Z\), the events \(\{F_Z^{-1}(U) \leq z\}\) and \(\{U \leq F_Z(z)\}\) coincide. Hence \(P(Z \leq z) = P(F_Z^{-1}(U) \leq z) = P(U \leq F_Z(z)) = \int_{0}^{F_Z(z)} 1~\text{d}u = F_Z(z)\), so \(Z\) follows the desired distribution \(P(z)\).
In other words, the inverse CDF method generates samples of \(Z\) by transforming a uniform random variable, where the transformation is the inverse CDF of \(Z\) itself. Most statistical software can easily generate samples from a uniform distribution.
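To make the lemma concrete, here is a minimal sketch (an illustration added here, not part of the original derivation) that samples from an exponential distribution, whose inverse CDF is available in closed form:

import numpy as np

# Exponential(rate) has CDF F(z) = 1 - exp(-rate * z),
# so its inverse CDF is F^{-1}(u) = -log(1 - u) / rate.
rate = 2.0
u = np.random.uniform(size=10000)   # U ~ Uniform(0, 1)
z = -np.log(1.0 - u) / rate         # Z = F^{-1}(U) ~ Exponential(rate)
print(z.mean())                     # should be close to 1 / rate = 0.5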
A coding example
Suppose \(Z\) is a categorical random variable, that is, \(Z\) can take values in \(\{1, 2, \dots, K\}\) with probabilities \(P(Z=k)=p_k\), where \(K\) denotes the number of categories. We would like to generate realizations (samples) of \(Z\), e.g. \(\{Z_1, Z_2, \cdots, Z_n\}\), from this distribution. We can employ the inverse CDF method, as sketched below; the full example is in the Code section at the end of the post.
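As a quick preview, the core of that example boils down to the following sketch (a condensed version of the inverse_cdf function in the Code section; variable names are illustrative):

import numpy as np

prob = [0.1, 0.2, 0.3, 0.4]        # p_k for K = 4 categories
cdf_of_z = np.cumsum(prob)         # cumulative probabilities F_Z(k)
u = np.random.uniform()            # U ~ Uniform(0, 1)
z = np.searchsorted(cdf_of_z, u)   # smallest (0-based) index k with F_Z(k) >= u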
Gumbel-max trick
Definition: The Gumbel distribution of a random variable \(X \in \mathbb{R}\) with location \(\mu\) and scale \(\beta\) has density \(f_X(x) = \frac{1}{\beta}\exp(-(x - \mu)/\beta - \exp(-(x - \mu)/\beta))\), CDF \(F_X(x) = P(X \leq x) = \exp(-\exp(-(x - \mu)/\beta))\), and inverse CDF \(F_X^{-1}(u) = \mu - \beta\log(-\log(u))\). For the scope of this discussion, it is enough to consider \(\mu=0, \beta=1\), that is \(X \sim \text{Gumbel}(0, 1)\). The Gumbel-max trick converts the sampling problem into an optimization problem, and the key result is as follows. \[\begin{align}
\tag{1}
I = \underset{k}{\text{argmax}}~(\log(p_k) + G_k) \sim \text{Categorical}(p_k), 1 \leq k \leq K
\end{align}\] where \(G_k \sim \text{Gumbel}(0, 1)\), \(1 \leq k \leq K\), are i.i.d. Gumbel-distributed random variables, and \(I \in \{1, \dots, K\}\) is a random variable because of the randomness in the \(G_k\)’s. Eq. (1) states that the index \(I\) follows the desired categorical distribution (see the proof below); the index is obtained by adding samples from the Gumbel distribution (the Gumbel noise) to the log-probabilities and choosing the index at which the resulting quantity is maximum. Proof: Let \(X_k = \log(p_k) + G_k \in \mathbb{R}\). The event \(\{I = k\}\) is equivalent to the event \(\{X_k > X_j~\forall j \neq k\}\) in the probability space. Its probability can be computed by integrating over the value \(x\) taken by \(X_k\), i.e. over the events \(\{X_k = x, X_j < x~\forall j \neq k\}\). Note that the \(G_k\)’s are i.i.d., so the \(X_k\)’s are mutually independent (though not identically distributed, since each is shifted by its own \(\log(p_k)\)). Hence, we have \[\begin{align}
P(I = k) &= P(X_k > X_j~\forall~j \neq k) = \int P(X_k = x, X_j < x~\forall~j \neq k)~\mathrm{d}x \\ \nonumber
&= \int P(X_k = x)\, P(X_j < x~\forall~j \neq k \mid X_k = x)~\mathrm{d}x \\ \nonumber
&= \int P(X_k = x)\, P(X_j < x~\forall~j \neq k)~\mathrm{d}x = \int P(X_k = x) \prod_{j \neq k} P(X_j < x)~\mathrm{d}x \\ \nonumber
&= \int P(G_k = x - \log(p_k)) \prod_{j \neq k} P(G_j < x - \log(p_j))~\mathrm{d}x \\ \nonumber
&= \int \exp(\log(p_k) - x - \exp(\log(p_k) - x)) \prod_{j \neq k} \exp(-\exp(\log(p_j) - x))~\mathrm{d}x \\ \nonumber
&= \int \exp(\log(p_k))\exp(-x)\exp(-\exp(\log(p_k) - x))\prod_{j \neq k} \exp(-\exp(\log(p_j) - x))~\mathrm{d}x \\ \nonumber
&= \int p_k \exp(-x)\prod_{j} \exp(-\exp(\log(p_j) - x))~\mathrm{d}x \\ \nonumber
&= \int p_k \exp(-x) \prod_{j} \exp(-p_j\exp(-x))~\mathrm{d}x \\ \nonumber
&= p_k \int \exp(-x) \exp\Big(-\sum_{j} p_j \exp(-x)\Big)~\mathrm{d}x \\ \nonumber
&= p_k \int_{-\infty}^{+\infty} \exp(-x) \exp(-\exp(-x))~\mathrm{d}x = p_k
\end{align}\] where the last step uses \(\sum_{j} p_j = 1\) and the fact that \(\exp(-x)\exp(-\exp(-x))\) is exactly the \(\text{Gumbel}(0, 1)\) density, which integrates to one. This completes the proof; hence the random variable \(I \sim \text{Cat}(p_k)\). To generate samples with this method, you only need to sample the Gumbel-distributed variables \(G_k\) and find the index that maximizes \(\log(p_k) + G_k\). In contrast to the inverse CDF method, this avoids computing the CDF or performing a binary search over it. To sample Gumbel-distributed variables, start with a uniform random variable \(U \sim \text{Uniform}(0, 1)\) and transform it with the inverse Gumbel CDF \(-\log(-\log(U))\), which is available in closed form; the result follows the \(\text{Gumbel}(0, 1)\) distribution. We can now generate realizations of \(I\), i.e. \(\{I_1, I_2, \dots, I_n\}\), by adding a fresh realization of the Gumbel noise to the log-probabilities each time and picking the maximizing index, as sketched below. The full example is in the Code section at the end of the post.
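Putting the pieces together, a minimal sketch of the Gumbel-max sampler (a condensed version of the gumbel_max function in the Code section) looks like this:

import numpy as np

prob = [0.1, 0.2, 0.3, 0.4]              # the p_k's
log_p = np.log(prob)
u = np.random.uniform(size=len(prob))    # one uniform sample per category
G = -np.log(-np.log(u))                  # Gumbel(0, 1) noise via the inverse CDF
I = np.argmax(log_p + G)                 # I ~ Categorical(p_k)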
Application to deep learning
The Gumbel-max trick finds significant application in deep learning, particularly in sampling from categorical distributions during optimization. Below are two common issues in deep learning that the trick helps address.
Issue-1
Optimizing categorical probabilities \(p_k\) directly requires solving a constrained problem with \(\sum_k p_k = 1\), which is hard to handle with gradient descent. To overcome this, one simple approach is to reparameterize \(p_k\) through the softmax of an unconstrained variable \(\theta_k \in \mathbb{R}\), i.e. \(p_k = \exp(\theta_k) / \sum_{j=1}^{K} \exp(\theta_j)\).
\(\theta_k\) could be the logits produced by a neural network. This formulation allows for unconstrained optimization, since the logits can take any real value. However, in some scenarios the normalization (the denominator of the softmax) is computationally expensive or simply not available; think of a streaming automatic speech recognition system, where not all the logits have arrived yet. The Gumbel-max trick overcomes this: following a proof along the same lines as above, it can be shown that \(\underset{k}{\text{argmax}}\{\theta_k + G_k\} \sim \text{Cat}(p_k)\) (you can do it as an exercise). The result holds for any values of the logits \(\theta_k\), so the normalizing constant never needs to be computed, as sketched below.
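A minimal sketch of this unnormalized variant, assuming some hypothetical logits \(\theta_k\) (the values below are made up for illustration), is:

import numpy as np

theta = np.array([1.2, -0.3, 0.7, 2.1])                    # hypothetical unnormalized logits theta_k
G = -np.log(-np.log(np.random.uniform(size=len(theta))))   # Gumbel(0, 1) noise
I = np.argmax(theta + G)   # I ~ Cat(p_k) with p_k = softmax(theta)_k; the softmax denominator is never computed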
Issue-2
Neural networks typically rely on gradient-based optimization, which requires all operations to be differentiable. Both the inverse CDF method and the Gumbel-max trick involve non-differentiable operations.
The inverse CDF method introduces discontinuities when selecting categories based on cumulative probabilities.
The Gumbel-max trick involves the \(\text{argmax}\) operation, which is non-differentiable.
While many researchers have proposed relaxations to make both methods differentiable, the softmax variant of the Gumbel-max trick, the “Gumbel-softmax trick”, is the one most often used. It replaces the \(\text{argmax}\) with a differentiable approximation (a softmax) that allows gradients to flow through the sampling process, making it suitable for backpropagation during neural network training. We leave a detailed treatment for a future discussion (reference: Gumbel-Softmax).
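For completeness, here is a rough sketch of that relaxation, following the standard Gumbel-softmax formulation with a temperature parameter tau (an assumption-laden illustration, not code from this post; as tau shrinks the output approaches a one-hot sample):

import numpy as np

def gumbel_softmax_sample(log_p, tau=0.5):
    # Differentiable relaxation of the Gumbel-max trick (standard formulation, assumed here):
    # add Gumbel noise to the log-probabilities and apply a temperature-scaled softmax
    # instead of the hard argmax.
    G = -np.log(-np.log(np.random.uniform(size=len(log_p))))
    y = (log_p + G) / tau
    y = y - np.max(y)                      # subtract the max for numerical stability
    return np.exp(y) / np.sum(np.exp(y))   # a point on the simplex, near one-hot for small tau

y = gumbel_softmax_sample(np.log([0.1, 0.2, 0.3, 0.4]), tau=0.5)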
Code
import numpy as np
import matplotlib.pyplot as plt
from time import time

# Inverse CDF method
def inverse_cdf(prob):
    u = np.random.uniform()
    cdf_of_z = np.cumsum(prob)
    # generalized inverse: find the minimum index j such that F(j) >= u;
    # numpy's "searchsorted" does exactly this on the cumulative probabilities
    sample = np.searchsorted(cdf_of_z, u)  # this yields a sample of Z
    return sample

# Gumbel-max trick
def gumbel_max(prob):
    log_p = np.log(prob)
    u = np.random.uniform(0.0, 1.0, len(log_p))
    G = -np.log(-np.log(u))   # inverse Gumbel CDF maps uniform samples to Gumbel(0, 1) samples
    I = np.argmax(log_p + G)  # add Gumbel noise to the log-probabilities and take the argmax
    return I

def est_prob(samples, n_samples):
    return np.bincount(samples) / n_samples

prob = [0.1, 0.2, 0.3, 0.4]  # the pk's for K = 4
n_samples = 5000

start_g = time()
samples_gumbel = [gumbel_max(prob) for _ in range(n_samples)]  # generate samples using Gumbel-max
end_g = time() - start_g
print("--" * 20)
print(f"Time taken by Gumbel-max: {end_g:0.4f} Sec")

start_cdf = time()
samples_invcdf = [inverse_cdf(prob) for _ in range(n_samples)]  # generate samples using inverse CDF
print(f"Time taken by inverse CDF: {time() - start_cdf:0.4f} Sec")

# estimate the probabilities from the samples of Z;
# more samples (n_samples) give better estimates
est_prob_gumbel = est_prob(samples_gumbel, n_samples)
est_prob_invcdf = est_prob(samples_invcdf, n_samples)
error_gumbel = abs(prob - est_prob_gumbel)
error_invcdf = abs(prob - est_prob_invcdf)

bar_width = 0.1
K = np.arange(len(prob))  # category indices for plotting
fig, ax = plt.subplots(2, 1)
ax[0].bar(K - bar_width, prob, bar_width, label="ORG. PROB.")
ax[0].bar(K, est_prob_invcdf, bar_width, color='brown', label="EST. PROB. (Inv CDF)")
ax[0].bar(K + bar_width, est_prob_gumbel, bar_width, color='green', label="EST. PROB. (Gumbel)")
ax[0].set_ylim([0, 1])
ax[0].set_ylabel("PROBABILITIES")
ax[1].bar(K, error_invcdf, bar_width, color='brown', label="ERROR (Inv CDF)")
ax[1].bar(K + bar_width, error_gumbel, bar_width, color='green', label="ERROR (Gumbel)")
ax[1].set_xlabel("CATEGORIES")
ax[1].set_ylabel("ERROR")
ax[0].legend()
ax[1].legend()

print(f"Avg. error Gumbel-max: {np.sum(error_gumbel)/len(prob):0.4f}")
print(f"Avg. error inv CDF: {np.sum(error_invcdf)/len(prob):0.4f}")
----------------------------------------
Time taken by Gumbel-max: 0.4369 Sec
Time taken by inverse CDF: 0.3250 Sec
Avg. error Gumbel-max: 0.0753
Avg. error inv CDF: 0.0050