Variational Autoencoder on MNIST

Learning a structured latent space for image generation with PyTorch

Introduction

A standard autoencoder maps an input to a fixed latent vector and reconstructs it. The problem is that the latent space has no enforced structure: two nearby latent vectors can decode to completely different outputs, which makes sampling and interpolation unreliable. You can verify this yourself: take a trained autoencoder, sample a random latent vector, and decode it. Most of the time you get noise, not a digit.

A VAE maps each input to a distribution over latent space, parameterized by a mean vector and a log-variance vector. Sampling from that distribution forces the model to learn a smooth, continuous latent space. The reparameterization trick keeps training differentiable: instead of sampling z directly, compute z = μ + σ·ε where ε ~ N(0, I), so the randomness is isolated in ε and gradients flow through μ and σ during backprop.
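The trick is a few lines in PyTorch. A minimal sketch (the function name is illustrative; the formula is the one above):

```python
import torch

def reparameterize(mu, log_var):
    """Sample z = mu + sigma * eps with eps ~ N(0, I).

    The randomness lives entirely in eps, so mu and log_var are
    ordinary differentiable tensors and gradients pass through them.
    """
    std = torch.exp(0.5 * log_var)   # sigma = exp(log_var / 2)
    eps = torch.randn_like(std)      # fresh standard-normal noise
    return mu + std * eps
```

Because eps carries no parameters, backprop treats the sample as a deterministic function of mu and log_var.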

The regularization comes from the KL divergence term in the loss, which penalizes the encoder for producing distributions that deviate from a standard normal. This is what forces structure into the space. Without it, the encoder can set σ to near zero and collapse to a deterministic autoencoder, using the mean as a fixed code with no stochasticity and no continuous structure.

Architecture & Training

The model operates on flattened 28×28 MNIST images (784 dimensions):

  • Encoder: Linear(784→512) + ReLU, then two parallel output heads, fc_mu and fc_logvar, each outputting a 40-dimensional vector. Log-variance is predicted instead of variance because it can take any real value, whereas a variance output would have to be constrained to be positive.
  • Latent space: Dimension 40. The reparameterization trick samples z = μ + exp(0.5·log_var)·ε from the encoded distribution.
  • Decoder: Linear(40→512) + ReLU → Linear(512→784) + Sigmoid, reconstructing pixel values in [0, 1].
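The bullets above translate into a compact nn.Module. A sketch assuming the layer sizes and head names (fc_mu, fc_logvar) stated above; the other attribute names are illustrative:

```python
import torch
import torch.nn as nn

class VAE(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=512, latent_dim=40):
        super().__init__()
        # Encoder: shared hidden layer, then parallel mu / log-variance heads
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        # Decoder: mirror back to flattened pixel space
        self.fc2 = nn.Linear(latent_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        h = torch.relu(self.fc1(x))
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        return mu + std * torch.randn_like(std)

    def decode(self, z):
        h = torch.relu(self.fc2(z))
        return torch.sigmoid(self.fc3(h))  # pixel values in [0, 1]

    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var
```

The forward pass returns mu and log_var alongside the reconstruction so the loss can compute the KL term.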

Loss function: Binary Cross-Entropy for reconstruction, plus KL divergence for regularization. The KL term has a closed-form expression for Gaussian distributions: -0.5 * sum(1 + log_var - mu² - exp(log_var)), so no sampling is needed to compute it.
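Both terms fit in one function. A sketch (the function name and kl_weight parameter are illustrative; the KL expression is the closed form above):

```python
import torch
import torch.nn.functional as F

def vae_loss(recon_x, x, mu, log_var, kl_weight=1.0):
    # Reconstruction: binary cross-entropy summed over pixels and batch
    bce = F.binary_cross_entropy(recon_x, x, reduction='sum')
    # Closed-form KL divergence between N(mu, sigma^2) and N(0, I):
    # -0.5 * sum(1 + log_var - mu^2 - exp(log_var))
    kld = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return bce + kl_weight * kld
```

Note that when mu = 0 and log_var = 0 (i.e. the encoder outputs exactly the prior), the KL term is zero, which is a handy sanity check.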

KL annealing is applied over the first 10 epochs, increasing the KL weight gradually from 0 to 1. Without annealing, the KL term can dominate early in training and push the encoder to collapse to the prior before it has learned anything useful. Annealing gives the reconstruction loss time to establish a meaningful encoding first.
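The text says the weight rises "gradually"; a linear ramp is one common choice (the exact schedule shape is an assumption here):

```python
def kl_weight(epoch, anneal_epochs=10):
    """Linear KL annealing: ramp the weight from 0 to 1 over the
    first `anneal_epochs` epochs, then hold it at 1."""
    return min(1.0, epoch / anneal_epochs)
```

At each epoch the returned weight multiplies the KL term before it is added to the reconstruction loss.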

Optimizer: Adam with lr = 1e-3, paired with CosineAnnealingLR over 30 training epochs to reduce the learning rate smoothly.
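Wiring up this optimizer and scheduler is standard PyTorch; a sketch with a stand-in model (the scheduler is stepped once per epoch, so the learning rate follows a cosine curve down to zero by epoch 30):

```python
import torch

model = torch.nn.Linear(784, 40)  # stand-in for the VAE
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30)

for epoch in range(30):
    # ... forward pass, loss, loss.backward(), optimizer.step() ...
    scheduler.step()  # one scheduler step per epoch
```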

Generation & Notebook

Training loss converges from 0.2746 down to 0.2628 over 30 epochs. The model reconstructs held-out MNIST digits clearly and generates novel digit images by sampling random latent vectors from N(0, I) and decoding them. The samples are recognizable as digits rather than noise, which confirms the latent space is well-structured.
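Generation is just decoding noise from the prior. A sketch assuming the model exposes a decode(z) method mapping latent vectors to flattened images (the function name is illustrative):

```python
import torch

@torch.no_grad()
def sample_digits(model, n=16, latent_dim=40):
    """Generate images by decoding random latent vectors from N(0, I)."""
    z = torch.randn(n, latent_dim)   # z ~ N(0, I)
    imgs = model.decode(z)           # (n, 784), sigmoid output in [0, 1]
    return imgs.view(n, 28, 28)
```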

The structured latent space lets you interpolate between two latent vectors and get smooth transitions between digits, something a regular autoencoder cannot reliably do. For example, interpolating linearly between the latent code for a 3 and the code for an 8 produces intermediate images that look like plausible but ambiguous digits. This is the direct result of the KL regularization filling in the gaps between training examples with a coherent, navigable space.
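Interpolation can be done by blending the encoded means of two inputs and decoding each intermediate point. A sketch assuming the model exposes encode(x) returning (mu, log_var) and decode(z) (function and parameter names are illustrative):

```python
import torch

@torch.no_grad()
def interpolate(model, x_a, x_b, steps=10):
    """Linearly interpolate between the latent means of two inputs
    and decode each intermediate latent vector into an image."""
    mu_a, _ = model.encode(x_a.view(1, -1))
    mu_b, _ = model.encode(x_b.view(1, -1))
    alphas = torch.linspace(0, 1, steps).unsqueeze(1)  # (steps, 1)
    z = (1 - alphas) * mu_a + alphas * mu_b            # (steps, latent_dim)
    return model.decode(z).view(steps, 28, 28)
```

Using the means rather than sampled z values gives a deterministic, noise-free path through the latent space.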