If you have a foundational understanding of probability theory and wish to explore how probability distributions can be modeled with neural networks, Variational Autoencoders (VAEs) offer a compelling framework. VAEs are a class of generative models that encode high-dimensional data into a lower-dimensional space within a probabilistic framework. Unlike standard autoencoders, this probabilistic nature makes VAEs more powerful for various real-world tasks.

Encoding data into a lower-dimensional space is crucial for several reasons, especially when dealing with complex data types like images and audio. For instance, consider an image containing objects such as trees, birds, and a person, with much of the background remaining blank. Despite the image’s many pixels, the essential information resides in a lower-dimensional space (the pixels corresponding to the objects). By capturing only the key features, we can simplify the representation of the image. Furthermore, lower-dimensional data requires less storage and fewer computational resources, accelerating processes like image classification or recognition. Similarly, consider an audio waveform that conveys not only spoken words but also information about the speaker’s accent, gender, and emotion. This nuanced information often lives in a lower-dimensional space that is not directly observable.

Typically, such underlying unobserved information can be modeled using the concept of latent variables from probability theory. To understand latent variables simply, consider the concept of happiness. Happiness is not something we can measure directly; instead, we infer it through observable indicators such as smiling, laughter, or self-reported satisfaction on surveys. In this context, happiness serves as a latent variable. Thus, for many applications, inferring latent variables from observations is of great interest. One effective technique for this is known as variational inference, which has strong connections to VAEs.
Variational inference
Let’s denote an observable variable by \(x \in \mathbb{R}^D\) or \(x \in \mathbb{Z}^D\) and the latent variable by \(z \in \mathbb{R}^L\), where \(L \ll D\). The VAE is an encoder-decoder architecture: the encoder is responsible for encoding the information into the latent space, and the decoder aims to reconstruct the original information from the latent representation. Under the paradigm of generative modeling, a VAE can be viewed as a latent variable model where we assume that the observable data \(x\) is generated from a hidden latent variable \(z\), and we would like to model the probability distribution of \(z\) given \(x\), \(p(z|x)\), also known as the posterior. This allows us to infer the characteristics of \(z\). In other words, we would like to compute \[\begin{align}
p(z|x) = \frac{p(x|z)p(z)}{\int p(x, z)\mathrm{d}z}
\end{align}\] However, directly computing this posterior is intractable for many practical use cases, since the denominator requires integrating over all possible values of the latent variable. VAEs introduce a solution through variational inference: instead of computing \(p(z|x)\) directly, a VAE approximates this posterior with a simpler distribution \(q(z|x)\), making it as close as possible to the true posterior \(p(z|x)\).
To achieve this, one could minimize the KL divergence between these two distributions over the latent variable \(z\). \[\begin{align}
KL(q(z|x)||p(z|x)) = E_{z \sim q(z|x)} \ln \frac{q(z|x)}{p(z|x)}
\end{align}\] For simplicity, we will write \(E_{z \sim q(z|x)}\) as \(E_{q}\). The KL divergence can be rewritten as follows. \[\begin{align}
KL(q(z|x)||p(z|x)) &= E_{q}\ln \frac{q(z|x)p(x)}{p(x|z)p(z)} = E_q \ln \frac{q(z|x)}{p(x|z)p(z)} + \ln p(x)
\end{align}\] where we have used \(E_q \ln p(x) = \int \ln p(x) q(z|x)\mathrm{d}z = \ln p(x) \int q(z|x)\mathrm{d}z = \ln p(x) \cdot 1\). The second term, \(\ln p(x)\), does not depend on \(z\) (or on \(q\)), so minimizing the KL divergence amounts to minimizing only the first term. Moreover, using the fact that \(KL(\cdot||\cdot) \geq 0\), the negative of the first term is a lower bound on the log-likelihood: \[\begin{align}
\ln p(x) \geq - E_q \ln \frac{q(z|x)}{p(x|z)p(z)}.
\end{align}\] In other words, we can maximize the RHS in this inequality. This quantity is known as the Evidence Lower Bound (\(ELBO\)). Re-writing \(ELBO\): \[\begin{align}
ELBO &= - E_q \ln \frac{q(z|x)}{p(x|z)p(z)} = E_q \ln p(z) + E_q \ln p(x|z) - E_q \ln q(z|x)
\end{align}\] The first term is the log-prior, the second term measures the faithfulness of the decoder's reconstruction of the original data \(x\) from the latent variable \(z\), and the third term, \(-E_q \ln q(z|x)\), is the entropy of the variational distribution over \(z\) given \(x\). In a VAE, the distributions \(q(z|x)\) and \(p(x|z)\) are parameterized by the encoder and decoder networks, respectively. It helps to keep the overall pipeline in mind: \(x \overset{q_{\phi}(z|x)}{\longrightarrow} z \overset{p_{\theta}(x|z)}{\longrightarrow} \tilde x\). Here \(\phi\) and \(\theta\) represent the parameters of the encoder and decoder, respectively. You are free to choose any neural network architecture for these two blocks of the VAE. Let’s discuss each of the distributions involved in \(ELBO\).
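To make this pipeline concrete, here is a minimal PyTorch sketch of the forward pass. The layer sizes are illustrative assumptions, and the differentiable way of sampling \(z\) is discussed later in the reparametrization section.

import torch
import torch.nn as nn

D, H, L = 784, 256, 16  # data, hidden, and latent dimensions (assumed values)

# encoder q_phi(z|x): outputs the parameters of a distribution over z
encoder = nn.Sequential(nn.Linear(D, H), nn.ReLU(), nn.Linear(H, 2 * L))
# decoder p_theta(x|z): maps a latent sample back to the data space
decoder = nn.Sequential(nn.Linear(L, H), nn.ReLU(), nn.Linear(H, D))

x = torch.rand(8, D)                             # a batch of 8 dummy inputs
mu, logvar = torch.chunk(encoder(x), 2, dim=-1)  # parameters of q_phi(z|x), each [8, L]
z = torch.normal(mu, torch.exp(0.5 * logvar))    # a sample z ~ q_phi(z|x) (naive sampling; see the reparametrization trick below)
x_tilde = torch.sigmoid(decoder(z))              # reconstruction, [8, D]
print(mu.shape, z.shape, x_tilde.shape)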
VAE encoder
We would like to choose \(q_{\phi}(z|x)\) to be mathematically tractable (after all that is why we introduced it). Typically, it is chosen to be a Gaussian. \[\begin{align}
q_{\phi}(z|x) \sim \mathcal{N}(z|\mu_{\phi}(x), \sigma_{\phi}^2(x)I);~ z \in \mathbb{R}^{L}, x \in \mathbb{R}^{D}
\end{align}\] where the mean \(\mu_{\phi}(x)\) and the variance \(\sigma_{\phi}^2(x)\) are the outputs of the encoder, and we have chosen the covariance matrix to be diagonal (\(I\) is an identity matrix), a common simplifying assumption. Since the variance has to be positive, during implementation you can let the encoder output the log-variance \(logvar\) and use \(\exp(logvar)\) to recover the variance on the linear scale if required; this trick prevents the variance from becoming negative.
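A tiny sketch of this trick, assuming the encoder's final layer outputs \(2L\) values per example (\(L\) for the mean and \(L\) for the log-variance):

import torch

enc_out = torch.randn(8, 32)                  # pretend encoder output, [B, 2L] with L = 16
mu, logvar = torch.chunk(enc_out, 2, dim=-1)  # mean and log-variance, each [B, L]
var = torch.exp(logvar)                       # variance on the linear scale, always positive
std = torch.exp(0.5 * logvar)                 # standard deviation, used when sampling z
assert (var > 0).all() and (std > 0).all()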
Prior
The prior \(p(z)\) plays a crucial role in a VAE. For our discussion, we model it using a standard Gaussian (zero mean and identity covariance), \[\begin{align*}
p(z) = \mathcal{N}(z|0, I)
\end{align*}\] where \(I\) denotes an identity matrix of size \(L \times L\) since \(z \in \mathbb{R}^L\).
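Since the \(ELBO\) requires evaluating \(\ln p(z)\) and \(\ln q_{\phi}(z|x)\), it is convenient to write down the log-densities of a standard Gaussian and of a diagonal Gaussian. The sketch below mirrors the helper functions used in the full code at the end of this post.

import torch
import numpy as np

LOG_2PI = float(np.log(2 * np.pi))

def log_normal_standard(z):
    # log N(z | 0, I), summed over the L latent dimensions -> shape [B]
    return -0.5 * z.shape[-1] * LOG_2PI - 0.5 * torch.sum(z ** 2, dim=-1)

def log_normal_diag(z, mu, logvar):
    # log N(z | mu, diag(exp(logvar))), summed over the L latent dimensions -> shape [B]
    return (-0.5 * z.shape[-1] * LOG_2PI
            - 0.5 * torch.sum(logvar, dim=-1)
            - 0.5 * torch.sum((z - mu) ** 2 / torch.exp(logvar), dim=-1))

z = torch.randn(8, 16)
print(log_normal_standard(z).shape)   # torch.Size([8])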
VAE decoder
The term \(p_{\theta}(x|z)\) represents the decoder likelihood for reconstructing \(x\) from \(z\), and hence it is also referred to as the reconstruction error. The form of this term depends on the type of the random variable \(x\).
- For categorical \(x\), we can model \(p_{\theta}(x|z)\) as a categorical distribution parameterized by the decoder probabilities.
- For continuous \(x\), we can use the mean squared error or the binary cross-entropy between the values of \(x\) and the predicted values from the decoder as the reconstruction error.
To understand further, let’s discuss these formulations of \(p_{\theta}(x|z)\) for the image generation task.
Formulation-1
Under this formulation, we assume that \(x\) is a categorical random vector \(x = [x_1, x_2, \dots, x_d, \dots, x_D]\) with \(x_d \in \{0, 1, \dots, L-1\}\). For instance, you can think of an image \(x\) of size \(8 \times 8\) where each pixel can take a value from \(0\) to \(16\), i.e. \(x \in \{0, 1, \dots, 16\}^{64}\); in this case, we have \(L=17, D=64\). The probability distribution of \(x_d\) is given by \[\begin{align*}
p(x_d) = \prod_{l=0}^{L-1} p_{dl}^{I_{\{x_d = l\}}}
\end{align*}\] where \[\begin{align}
I_{\{x_d = l\}} =
\begin{cases}
1 & \text{if } x_d = l \\
0 & \text{otherwise}
\end{cases}
\end{align}\] is an indicator variable and \(p_{dl} = P(x_d=l)\). Assuming that the \(x_d\)’s are independent given \(z\), we have \[\begin{align*}
p_{\theta}(x|z) = \prod_{d=1}^{D}\prod_{l=0}^{L-1} p_{dl}^{I_{\{x_d = l\}}}
\end{align*}\] In this equation the probabilities \(p_{d} = [p_{d0}, \dots, p_{dl}, \dots, p_{d,L-1}] \in \mathbb{R}^{L}\) are given by the output layer of the decoder, i.e. \(p_d = \mathrm{softmax}(f_{\theta}(z))\) with \(f_{\theta}(z) \in \mathbb{R}^{L}\) per pixel and \(f_{\theta}(\cdot)\) representing the decoder.
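Here is a minimal sketch of this categorical reconstruction term, assuming the decoder produces \(L\) logits per pixel (a tensor of shape [B, D, L]); the gather on the log-softmax is one way to compute \(\ln p_{\theta}(x|z)\), and F.cross_entropy with permuted axes is an equivalent route.

import torch
import torch.nn.functional as F

B, D, L = 8, 64, 17                # batch, pixels, intensity levels (assumed)
logits = torch.randn(B, D, L)      # f_theta(z): L logits per pixel from the decoder
x = torch.randint(0, L, (B, D))    # categorical pixels x_d in {0, ..., L-1}

log_p = F.log_softmax(logits, dim=-1)                                        # ln p_{dl}, [B, D, L]
log_px_given_z = log_p.gather(-1, x.unsqueeze(-1)).squeeze(-1).sum(dim=-1)   # ln p(x|z), [B]

# equivalently: -F.cross_entropy(logits.permute(0, 2, 1), x, reduction='none').sum(dim=-1)
print(log_px_given_z.shape)        # torch.Size([8])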
Formulation-2
In Formulation-1, we derived the reconstruction term \(p_{\theta}(x|z)\) for the discrete case, i.e. \(x \in \mathbb{Z}^D\). In this formulation, we treat the continuous case, i.e. \(x \in \mathbb{R}^D\). Further, we assume that each \(x_d\) is normalized, i.e. \(x_d \in [0, 1]\). An appropriate choice is the binary cross-entropy with soft labels given by \(x_d\). \[\begin{align*}
\ln p_{\theta}(x|z) = \sum_{d=1}^{D} \big[ x_d \ln \hat p_d(z) + (1 - x_d) \ln\big(1 - \hat p_d(z)\big) \big]
\end{align*}\] where \(\hat p_d(z) \in (0, 1)\) is the \(d\)-th output of the decoder, i.e. \(\hat p(z) = \mathrm{sigmoid}(f_{\theta}(z)) \in (0, 1)^D\). We use \(\mathrm{sigmoid}(\cdot)\) to squash each pixel value between \(0\) and \(1\).
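A short sketch of this continuous-case reconstruction term, assuming the pixels have been normalized to \([0, 1]\); note that F.binary_cross_entropy returns the negative of this log-likelihood.

import torch
import torch.nn.functional as F

B, D = 8, 784
x = torch.rand(B, D)            # normalized pixels in [0, 1] (soft labels)
dec_out = torch.randn(B, D)     # raw decoder outputs f_theta(z)
p_hat = torch.sigmoid(dec_out)  # \hat p_d(z) in (0, 1)

eps = 1e-12                     # clamp to avoid log(0), as in the full code below
p_hat = torch.clamp(p_hat, eps, 1 - eps)
log_px_given_z = (x * torch.log(p_hat) + (1 - x) * torch.log(1 - p_hat)).sum(dim=-1)  # [B]

# the reconstruction error used in the loss is the negative of this quantity, e.g.
# RE = F.binary_cross_entropy(p_hat, x, reduction='none').sum(dim=-1)
print(log_px_given_z.shape)     # torch.Size([8])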
To summarize the two distributions we parameterize:
- \(p_{\theta}(x|z)\) is the reconstruction error and its form depends on the type of \(x\).
- \(q_{\phi}(z|x) \sim \mathcal{N}(z|\mu_{\phi}(x), \sigma^2_{\phi}(x)I)\) is known as the variational posterior.
VAE training and the reparametrization trick
With this understanding of each of the terms involved in \(ELBO\), let’s now return to it. During VAE training, we need to solve the following optimization problem: \[\begin{align*}
\underset{\phi, \theta}{\mathrm{argmax}} ELBO(\phi, \theta) = \underset{\phi, \theta}{\mathrm{argmax}} E_{q_{\phi}}[\ln p(z) + \ln p_{\theta}(x|z) - \ln q_{\phi}(z|x)]
\end{align*}\] This requires us to take the gradients of \(ELBO(\phi, \theta)\) with respect to the parameters \(\phi\) and \(\theta\). Let’s denote the gradient operator by \(\nabla\). There is a hurdle in computing the gradient w.r.t. \(\phi\): we cannot move the gradient operator inside the expectation, because the expectation itself is taken over a distribution that depends on \(\phi\); mathematically speaking, \(\nabla_{\phi} E_{q_{\phi}}\ln q_{\phi}(z|x) \neq E_{q_{\phi}} \nabla_{\phi} \ln q_{\phi}(z|x)\). In contrast, such an issue does not arise for \(\nabla_{\theta}\). We use the reparametrization trick to overcome this hurdle. The idea is to make the expectation independent of \(\phi\) by reparameterizing the distribution of \(z\). Consider \[\begin{align*}
z &= g_{\phi}(\epsilon, x) = \mu_{\phi}(x) + \sigma_{\phi}(x) \odot \epsilon~\mathrm{where}~\epsilon \sim p(\epsilon) = \mathcal{N}(\epsilon|0, I),~\epsilon \in \mathbb{R}^{L}
\end{align*}\] In other words, we sample an auxiliary random variable \(\epsilon\) from a fixed distribution that does not depend on \(\phi\) and apply an affine transformation to it to obtain \(z\). We can then rewrite \(ELBO\) as follows. \[\begin{align*}
ELBO(\phi, \theta) = E_{p(\epsilon)}[\ln p(g_{\phi}(\epsilon, x)) + \ln p_{\theta}(x|g_{\phi}(\epsilon, x)) - \ln q_{\phi}(g_{\phi}(\epsilon, x)|x)]
\end{align*}\] We use Monte Carlo simulation for computing the expectation. \[\begin{align*}
ELBO(\phi, \theta) &\approx \frac{1}{K}\sum_{k=1}^{K} [\ln p(g_{\phi}(\epsilon^{(k)}, x)) + \ln p_{\theta}(x|g_{\phi}(\epsilon^{(k)}, x)) - \ln q_{\phi}(g_{\phi}(\epsilon^{(k)}, x)|x)]\\
\mathrm{where}~\quad \quad &\epsilon^{(k)} \sim \mathcal{N}(\epsilon|0, I);~ k = 1, \dots, K
\end{align*}\] Typically, we use a batch of examples \(\{x^{(n)}\}_{n=1}^N\) during training. \[\begin{align*}
ELBO\big(\phi, \theta; \{x^{(n)}\}_{n=1}^{N}\big) &\approx \frac{1}{KN} \sum_{n=1}^N \sum_{k=1}^{K} \bigg[\ln p\big(g_{\phi}(\epsilon^{(k)}, x^{(n)})\big) + \ln p_{\theta}\big(x^{(n)}|g_{\phi}(\epsilon^{(k)}, x^{(n)})\big) - \ln q_{\phi}\big(g_{\phi}(\epsilon^{(k)}, x^{(n)})|x^{(n)}\big)\bigg]\\
\mathrm{where}~\quad \quad &\epsilon^{(k)} \sim \mathcal{N}(\epsilon|0, I);~ k = 1, \dots, K
\end{align*}\] This concludes our discussion of the mathematical concepts underlying Variational Autoencoders. As is customary in my blogs, I have included a Python code snippet to reinforce your understanding of these ideas. I strongly encourage you to examine the code line by line and relate it to the mathematics discussed. Additionally, reviewing Gaussian distributions will enhance your comprehension of the code.

You may be curious about the implications of not using the reparameterization trick and instead sampling directly from a normal distribution with the mean and variance given by the encoder’s outputs. In this context, it’s important to note that the \(ELBO\) involves expectations over this distribution. Without the reparameterization trick, sampling a random variable from a distribution that depends on the optimization variables introduces stochasticity during backpropagation. This stochasticity complicates the gradient estimation process, leading to high-variance gradients. As a result, training the VAE becomes unstable without the reparameterization trick. To observe this effect, you can set use_rep=False in the Python code and analyze the behavior of the loss function. From the resulting loss curves, you will likely infer that the model fails to converge effectively. In contrast, employing the reparameterization trick stabilizes training and allows for more reliable convergence.
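As a minimal illustration of this point, using torch.distributions (the full code below instead toggles its sampling strategy through use_rep): rsample() draws \(z\) via the reparameterization trick so that gradients reach the encoder outputs, while sample() produces a draw that is detached from them.

import torch

mu = torch.zeros(4, 16, requires_grad=True)       # stand-ins for encoder outputs
logvar = torch.zeros(4, 16, requires_grad=True)
dist = torch.distributions.Normal(mu, torch.exp(0.5 * logvar))

z_rep = dist.rsample()          # reparameterized: z = mu + std * eps, eps ~ N(0, I)
z_rep.sum().backward()
print(mu.grad is not None)      # True: gradients flow back to mu and logvar

mu.grad = None
z_no_rep = dist.sample()        # plain sampling: no gradient path to mu / logvar
print(z_no_rep.requires_grad)   # False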
The Coding Example and Key Focus Areas
VAE Model Implementation: This example demonstrates the VAE model built from scratch for the MNIST dataset.
Educational Purpose: The code is purely for educational purposes. You should be able to relate it to the mathematics behind the probability distributions involved in the \(ELBO\).
Continuous Data Focus: We focus on the continuous case of data where \(x \in \mathbb{R}^D\). If you are interested in studying the modeling of reconstruction error for categorical data, you can visit VAE.
ELBO Implementation: Pay attention to the implementation of the \(ELBO\).
Image Synthesis: A few synthesized images are saved on disk if save_images is set to True in the configuration.
Hardware Compatibility: The code is general enough to be executed on a CPU or single/multiple GPUs.
Code
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
import numpy as np
import os, sys
import matplotlib.pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}, n_gpus: {torch.cuda.device_count()}")

# access dictionary entries as attributes
class AttributeDict(dict):
    def __getattr__(self, attr):
        return self[attr]
    def __setattr__(self, attr, value):
        self[attr] = value

config = AttributeDict({"lr": 1e-3, "batch_size": 16, "EPS": 1e-12, "epochs": 50,
                        "loss_period": 10, "max_patience": 20, "save_images": False})

# MNIST dataset download
class DatasetMnist(Dataset):
    def __init__(self, train=True, transforms=None):
        super(DatasetMnist, self).__init__()
        self.transforms = transforms
        if train:
            self.data = datasets.MNIST(root="./train", train=train, download=True, transform=None)
        else:
            self.data = datasets.MNIST(root="./test", train=train, download=True, transform=None)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        image, label = self.data[idx]
        if self.transforms:
            image = self.transforms(image)
        return image, label

transform = transforms.Compose([transforms.ToTensor()])
full_train_data = DatasetMnist(train=True, transforms=transform)
full_test_data = DatasetMnist(train=False, transforms=transform)

# small random subsets keep the example fast to run
np.random.seed(300)
n_train = np.random.choice(len(full_train_data), 500)
n_test = np.random.choice(len(full_test_data), 100)
train_data = torch.utils.data.Subset(full_train_data, indices=n_train)
val_data = torch.utils.data.Subset(full_test_data, indices=n_test)
train_loader = DataLoader(train_data, batch_size=config.batch_size, shuffle=True)
valid_loader = DataLoader(val_data, batch_size=config.batch_size, shuffle=False)
print(f"MNIST -> Train: {len(train_data)}, Valid: {len(val_data)}")

x_sample, y_sample = next(iter(train_loader))
D = x_sample.shape[2] * x_sample.shape[3]  # input dim (28 * 28)
H = 512   # hidden units
L = 16    # latent dim
PI = torch.tensor(np.pi)
EPS = config.EPS

# encoder outputs 2L values per example: L for mu and L for logvar
encoder_net = nn.Sequential(nn.Linear(D, H), nn.ReLU(), nn.Linear(H, H), nn.ReLU(), nn.Linear(H, 2 * L))
decoder_net = nn.Sequential(nn.Linear(L, H), nn.ReLU(), nn.Linear(H, H), nn.ReLU(), nn.Linear(H, D))

def log_normal_diag(z, mu, logvar):
    # ln N(z | mu, diag(exp(logvar))), z: [B, L] -> [B]
    log_p = (-0.5 * z.shape[-1] * torch.log(2 * PI)
             - 0.5 * torch.sum(logvar, dim=-1)
             - 0.5 * torch.sum(((z - mu) ** 2) / torch.exp(logvar), dim=-1))
    return log_p  # [B]

def log_normal_standard(z):
    # ln N(z | 0, I), z: [B, L] -> [B]
    log_p = -0.5 * z.shape[-1] * torch.log(2 * PI) - 0.5 * torch.sum(z ** 2, dim=-1)
    return log_p  # [B]

class Prior(nn.Module):
    # standard Gaussian prior p(z) = N(z | 0, I)
    def __init__(self):
        super(Prior, self).__init__()

    def log_prob(self, z):
        return log_normal_standard(z)  # [B]

# VAE
class VAE(nn.Module):
    def __init__(self, encoder_net, decoder_net, use_rep=True):
        super(VAE, self).__init__()
        self.encoder = encoder_net
        self.decoder = decoder_net
        self.prior = Prior()
        self.use_rep = use_rep

    def sample_from_encoder(self, mu, logvar):
        assert not mu.isnan().any(), "mu -> nan"
        assert not logvar.isnan().any(), "logvar -> nan"
        std = torch.exp(0.5 * logvar)
        if self.use_rep:
            # reparametrization trick: z = mu + std * eps, eps ~ N(0, I)
            eps = torch.randn_like(std)
            z = mu + std * eps
        else:
            # naive sampling: no gradient path from z back to mu / logvar
            z = torch.normal(mu, std)
        assert not z.isnan().any(), "z -> nan in sample_from_encoder"
        return z

    def loss(self, x, enc_log_p, log_prior, dec_p):
        # x --> encoder: q(z|x) --> z --> decoder: p(x|z)
        # ELBO = E_{z~q(z|x)}[ln p(z) + ln p(x|z) - ln q(z|x)] = log_prior + dec_log_p - enc_log_p
        # assumptions: q(z|x) = N(z | mu(x), diag(var(x))), p(z) = N(z | 0, I), x: [B, D], z: [B, L]
        dec_p = torch.clamp(dec_p, EPS, 1 - EPS)  # avoid log(0)
        # reconstruction error: binary cross-entropy with soft labels x (Formulation-2)
        RE = -(x * torch.log(dec_p) + (1 - x) * torch.log(1 - dec_p))
        RE = torch.sum(RE, dim=-1)  # [B]; p(x|z) factorizes over the D dimensions, hence the sum of log-probabilities
        KL = log_prior - enc_log_p  # [B]; Monte Carlo estimate of ln p(z) - ln q(z|x)
        neg_elbo = RE - KL          # [B]
        return torch.mean(neg_elbo)  # batch-wise average loss

    def forward(self, x):
        h_e = self.encoder(x)
        mu, logvar = torch.chunk(h_e, 2, dim=-1)    # mu, logvar: [B, L]
        z = self.sample_from_encoder(mu, logvar)    # [B, L], a single Monte Carlo sample (K = 1)
        h_d = self.decoder(z)
        dec_p = torch.sigmoid(h_d)                  # [B, D], continuous case
        # dec_p = torch.softmax(h_d, dim=-1)        # alternative for the categorical case (Formulation-1)
        enc_log_p = log_normal_diag(z, mu, logvar)  # ln q(z|x), [B]
        log_prior = self.prior.log_prob(z)          # ln p(z), [B]
        return mu, logvar, enc_log_p, log_prior, dec_p

# training
model = VAE(encoder_net, decoder_net, use_rep=True)
if torch.cuda.device_count() > 1:
    print(f"Available GPUs: {torch.cuda.device_count()}")
    model = nn.DataParallel(model)
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)

def check(value):
    # optional sanity check used while debugging
    assert not value.isnan().any(), f"{value} -> nan"
    assert value.isfinite().all(), f"{value} -> inf"

def train(train_loader, model, optimizer):
    model.train()
    avg_loss, total_samples = 0.0, 0
    for x, y in train_loader:
        batch_size = x.shape[0]
        x = x.view(-1, D)
        x = (x - torch.min(x)) / (torch.max(x) - torch.min(x))  # normalize pixels to [0, 1]
        x = x.to(device)
        mu, logvar, enc_log_p, log_prior, dec_p = model(x)
        if torch.cuda.device_count() > 1:
            loss_ = model.module.loss(x, enc_log_p, log_prior, dec_p)  # batch-wise avg loss
        else:
            loss_ = model.loss(x, enc_log_p, log_prior, dec_p)  # batch-wise avg loss
        optimizer.zero_grad()
        loss_.backward()
        optimizer.step()
        avg_loss += loss_.item() * batch_size
        total_samples += batch_size
    return avg_loss / total_samples

def save_images(org, recons, epoch):
    """org, recons: [B, D]; saves originals (top row) and reconstructions (bottom row)."""
    os.makedirs("./vae_results", exist_ok=True)
    save_here = f"./vae_results/vae_output_{epoch}.png"
    num_images = org.shape[0]
    assert num_images % 2 == 0, "Can not save images, batch_size must be even!"
    org = org[:num_images].view(-1, 28, 28)
    recons = recons[:num_images].view(-1, 28, 28)
    img = torch.cat((org, recons), 0).detach().cpu().numpy()
    fig, ax = plt.subplots(2, num_images)
    for idx in range(2 * num_images):
        ax[idx // num_images, idx % num_images].imshow(img[idx])
        ax[idx // num_images, idx % num_images].axis("off")
    plt.savefig(save_here, bbox_inches='tight')
    plt.close(fig)

def validate(valid_loader, model, epoch):
    model.eval()
    avg_loss, total_samples = 0.0, 0
    with torch.no_grad():
        for x, y in valid_loader:
            batch_size = x.shape[0]
            x = x.view(-1, D)
            x = (x - torch.min(x)) / (torch.max(x) - torch.min(x))
            x = x.to(device)
            mu, logvar, enc_log_p, log_prior, dec_p = model(x)
            if torch.cuda.device_count() > 1:
                loss_ = model.module.loss(x, enc_log_p, log_prior, dec_p)
            else:
                loss_ = model.loss(x, enc_log_p, log_prior, dec_p)
            avg_loss += loss_.item() * batch_size
            total_samples += batch_size
            if config.save_images:
                save_images(x, dec_p, epoch)
    return avg_loss / total_samples

train_loss, val_loss = [], []
for epoch in range(config.epochs):
    train_loss_ = train(train_loader, model, optimizer)
    val_loss_ = validate(valid_loader, model, epoch)
    train_loss.append(train_loss_)
    val_loss.append(val_loss_)
    if (epoch + 1) % config.loss_period == 0:
        print(f"Epoch: {epoch + 1}, Train Loss: {train_loss_:0.4f}, Val Loss: {val_loss_:0.4f}")

plt.plot(train_loss, label='training')
plt.plot(val_loss, label='validation')
plt.legend()