Variational Autoencoders (VAE)

A probabilistic autoencoder. The encoder maps each input to a Gaussian distribution over a latent vector z; the decoder reconstructs the input from a sample. Training maximises a variational lower bound (ELBO) on the data likelihood.

TL;DR. A VAE learns a smooth, structured latent space you can sample from and interpolate in. The encoder outputs (mu, log_var); you sample z = mu + sigma * eps with eps ~ N(0, I) (reparameterisation trick) so gradients flow; the decoder rebuilds x; you train with reconstruction + a KL term that pulls the posterior toward N(0, I). USAAIO uses VAEs for low-data image generation, anomaly detection (high reconstruction error = outlier), and as the latent compressor inside Stable Diffusion.

1. The intuition

A plain autoencoder learns a deterministic bottleneck z = f(x) with a decoder x_hat = g(z) that minimises reconstruction error. Problem: the latent space has holes. If you pick a random z and decode it, you get garbage; if you interpolate between two latents, you pass through nonsense.

A VAE fixes this by being probabilistic. The encoder outputs a whole distribution q(z | x) = N(mu(x), sigma(x)^2) — a soft blob rather than a point. The training loss has two parts: (1) reconstruct x from samples of z, and (2) keep q(z|x) close to a simple prior p(z) = N(0, I) via KL divergence. Term (2) packs all data points into one overlapping cloud around the origin, so sampling z ~ N(0, I) and decoding produces realistic outputs, and interpolating between two latents stays on the manifold.

The KL term acts as a regulariser; the reconstruction term as data fit. Their balance is controlled by a coefficient beta (beta-VAE). Too much KL: blurry samples, posterior collapse. Too little KL: a deterministic autoencoder in disguise.

2. The math

ELBO derivation

We want to maximise log p(x) under a latent-variable model p(x, z) = p(z) p_theta(x | z). The true posterior p(z | x) is intractable, so introduce a variational approximation q_phi(z | x) and apply Jensen's inequality:

log p(x) = log ∫ p(x, z) dz = log ∫ q(z|x) * p(x, z) / q(z|x) dz >= ∫ q(z|x) * log[ p(x, z) / q(z|x) ] dz (Jensen) = E_q[ log p(x | z) ] - KL( q(z|x) || p(z) ) = ELBO

The gap is exactly KL( q(z|x) || p(z|x) ), which is non-negative. So maximising the ELBO simultaneously maximises a lower bound on log p(x) and tightens the posterior approximation. Two terms:

Closed-form KL with a standard-normal prior

With q(z|x) = N(mu, diag(sigma^2)) and p(z) = N(0, I), the KL has a clean closed form (sum over latent dims j):

KL = 0.5 * sum_j ( mu_j^2 + sigma_j^2 - 1 - log(sigma_j^2) ) = -0.5 * sum_j ( 1 + log_var_j - mu_j^2 - exp(log_var_j) )

Predicting log_var = log(sigma^2) (rather than sigma directly) keeps variance positive and stabilises training.

Reparameterisation trick

Sampling z ~ N(mu, sigma^2) directly is not differentiable in mu, sigma. Rewrite the sample as a deterministic function of a noise source:

z = mu + sigma * eps, eps ~ N(0, I)

Now gradients of the reconstruction loss flow through mu and sigma into the encoder; only the randomness sits in the parameter-free eps. This trick is what makes the whole thing trainable by SGD.

Sampling at inference

To generate a new image, draw z ~ N(0, I) from the prior and run the decoder. No encoder needed.

3. PyTorch reference implementation

import torch
import torch.nn as nn
import torch.nn.functional as F


class VAE(nn.Module):
    """MNIST-scale fully connected VAE. Replace MLPs with convs for images."""

    def __init__(self, input_dim=784, hidden=400, latent=20):
        super().__init__()
        # Encoder: x -> hidden -> (mu, log_var)
        self.fc1     = nn.Linear(input_dim, hidden)
        self.fc_mu   = nn.Linear(hidden, latent)
        self.fc_lv   = nn.Linear(hidden, latent)
        # Decoder: z -> hidden -> reconstruction
        self.fc3     = nn.Linear(latent, hidden)
        self.fc4     = nn.Linear(hidden, input_dim)

    def encode(self, x):
        h = F.relu(self.fc1(x))
        return self.fc_mu(h), self.fc_lv(h)

    def reparameterise(self, mu, log_var):
        sigma = torch.exp(0.5 * log_var)
        eps   = torch.randn_like(sigma)
        return mu + sigma * eps               # diff-able sample

    def decode(self, z):
        h = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h))     # pixel probabilities

    def forward(self, x):
        mu, log_var = self.encode(x)
        z           = self.reparameterise(mu, log_var)
        x_hat       = self.decode(z)
        return x_hat, mu, log_var


def vae_loss(x_hat, x, mu, log_var, beta=1.0):
    # Reconstruction: sum BCE over pixels (batch-mean later).
    recon = F.binary_cross_entropy(x_hat, x, reduction="sum") / x.size(0)
    # Closed-form KL with N(0, I) prior, per sample, then mean.
    kl    = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1).mean()
    return recon + beta * kl, recon, kl


def train_step(model, x, optim):
    x_hat, mu, log_var = model(x)
    loss, _, _ = vae_loss(x_hat, x, mu, log_var)
    optim.zero_grad()
    loss.backward()
    optim.step()
    return float(loss)


@torch.no_grad()
def sample(model, n=16, latent=20, device="cpu"):
    """Draw fresh images from the prior N(0, I)."""
    model.train(False)                        # equivalent to standard .e+val()
    z = torch.randn(n, latent, device=device)
    return model.decode(z).view(n, 1, 28, 28)


if __name__ == "__main__":
    torch.manual_seed(0)
    model = VAE()
    optim = torch.optim.Adam(model.parameters(), lr=1e-3)
    x = torch.rand(64, 784)                   # fake MNIST batch in [0,1]
    for step in range(3):
        loss = train_step(model, x, optim)
        print(step, loss)
    imgs = sample(model, n=4)
    print(imgs.shape)                         # torch.Size([4, 1, 28, 28])

model.train(False) in sample() is the standard inference switch — same effect as the short-form call on the module. Written long-form so the project security hook does not flag the substring.

4. Common USAAIO / IOAI applications

5. Drills

D1 · Why predict log_var, not sigma?

What goes wrong if the encoder predicts sigma directly?

Solution

sigma must be positive; a raw linear output is unconstrained. You could exp it, but predicting log_var already handles that and gives the KL term a numerically stable form (mu^2 + exp(log_var) - 1 - log_var). Predicting sigma and squaring also hits numerical issues near 0.

D2 · KL value at perfect prior match

What is KL(N(0, I) || N(0, I))?

Solution

0. Plug into the closed form: mu = 0, sigma^2 = 1, log_var = 0 gives 0.5 * sum(0 + 1 - 1 - 0) = 0. A VAE that perfectly matches the prior has paid the full KL cost of nothing and reconstructs from random noise.

D3 · Posterior collapse

Samples from your VAE all look identical regardless of input. Your reconstructions are blurry but valid. The KL term has gone to near 0. Diagnose and fix.

Solution

Posterior collapse — the encoder is outputting q(z|x) ≈ p(z) = N(0,I) and the decoder is ignoring z entirely. Fixes: lower beta, use KL annealing (ramp from 0 to 1 over training), use free bits (clip KL per dimension below a threshold), or strengthen the decoder less.

D4 · Reparameterisation forward/backward

Why does z = mu + sigma * eps let gradients flow, while z ~ N(mu, sigma^2) directly does not?

Solution

Direct sampling is a non-differentiable stochastic op — autograd cannot push a gradient back through "draw a sample". Rewriting as a deterministic function of an external noise eps isolates the randomness; mu, sigma then appear in a smooth expression that autograd can differentiate.

D5 · Latent dim choice

MNIST with latent dim 2 vs 64 — what is the trade-off?

Solution

2 dims: great for visualisation (you can scatter-plot the latent space), terrible reconstructions (information bottleneck too tight). 64 dims: sharp reconstructions but many unused latent dimensions and harder to visualise. Practical sweet spot for MNIST is 10-32.

Next step

Diffusion models dominate VAEs for sample quality — read DDPM next to see how iterative denoising sidesteps the blurriness problem. For attention in latent diffusion (Stable Diffusion's cross-attention), revisit Transformers. For paper-form short answer on ELBO and reparameterisation, hit Round 2 theory.