Training Diffusion Models from Scratch

CS5670 Project 5B - Introduction to Computer Vision
Cornell University, Spring 2025

Train your own diffusion model on the MNIST dataset from the ground up.

This project builds a working understanding of diffusion model training: implementing UNet architectures, working through the DDPM mathematics, and progressing from simple one-step denoisers to full conditional diffusion models with classifier-free guidance.

Figure: Training progression from noisy samples to clean MNIST digit generation over 20 epochs

Project Details

Assigned: Thursday, April 17, 2025
Due: Tuesday, May 6, 2025 (Part B)
Individual Work: Must be completed individually
Platform: Google Colab with GPU acceleration
Key Concepts: UNet architecture, DDPM training, time conditioning, CFG

Overview

Training a diffusion model from scratch shows how these generative models learn to turn noise into meaningful data. This project covers the full path from implementing a UNet to training a class-conditional diffusion model with classifier-free guidance.

Working with the MNIST dataset allows us to focus on the core algorithms while maintaining reasonable training times. We progress from simple one-step denoisers to full DDPM implementation with time and class conditioning.

Part 1: Single-Step Denoising UNet

1.1 UNet Architecture Implementation

The UNet serves as the backbone of our diffusion model, featuring downsampling and upsampling paths with skip connections for multi-scale feature processing.

Figure: Unconditional UNet architecture

Core UNet Components:

  1. Conv: Convolutional layer preserving spatial dimensions
  2. DownConv: Downsampling convolution (stride=2)
  3. UpConv: Upsampling transposed convolution (stride=2)
  4. Flatten: Average pooling to reduce 7×7 → 1×1
  5. Unflatten: Expand 1×1 → 7×7 through convolution
  6. Skip Connections: Concatenate features across scales
Figure: UNet operations in detail
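The components listed above can be sketched in PyTorch as follows. This is a minimal sketch, not the project's actual code: the kernel sizes, the BatchNorm + GELU pairing, the transposed convolution used for Unflatten, and the channel counts in the shape walk-through are all assumptions chosen so that a 28×28 MNIST image moves through 14×14, 7×7, and 1×1 resolutions.

```python
import torch
import torch.nn as nn


def _block(layer: nn.Module, out_ch: int) -> nn.Sequential:
    """Wrap a (transposed) convolution with BatchNorm and GELU (assumed choices)."""
    return nn.Sequential(layer, nn.BatchNorm2d(out_ch), nn.GELU())


def conv(in_ch: int, out_ch: int) -> nn.Sequential:
    """3x3 convolution that preserves height and width."""
    return _block(nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1), out_ch)


def down_conv(in_ch: int, out_ch: int) -> nn.Sequential:
    """Stride-2 convolution that halves height and width."""
    return _block(nn.Conv2d(in_ch, out_ch, 3, stride=2, padding=1), out_ch)


def up_conv(in_ch: int, out_ch: int) -> nn.Sequential:
    """Stride-2 transposed convolution that doubles height and width."""
    return _block(nn.ConvTranspose2d(in_ch, out_ch, 4, stride=2, padding=1), out_ch)


def flatten() -> nn.Sequential:
    """Average-pool a 7x7 feature map down to 1x1."""
    return nn.Sequential(nn.AvgPool2d(7), nn.GELU())


def unflatten(in_ch: int, out_ch: int) -> nn.Sequential:
    """Expand a 1x1 feature map back to 7x7 with a transposed convolution."""
    return _block(nn.ConvTranspose2d(in_ch, out_ch, 7, stride=7), out_ch)


# Shape walk-through on a random batch standing in for MNIST images.
x = torch.randn(4, 1, 28, 28)
h = conv(1, 8)(x)                    # (4, 8, 28, 28)
d1 = down_conv(8, 8)(h)              # (4, 8, 14, 14)
d2 = down_conv(8, 16)(d1)            # (4, 16, 7, 7)
f = flatten()(d2)                    # (4, 16, 1, 1)
u = unflatten(16, 16)(f)             # (4, 16, 7, 7)
skip = torch.cat([u, d2], dim=1)     # skip connection: (4, 32, 7, 7)
up = up_conv(32, 8)(skip)            # (4, 8, 14, 14)
```

Concatenating along the channel dimension is why the layer after each skip connection takes twice as many input channels.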

1.2 Single-Step Denoiser Training

We begin with the fundamental denoising task: given a noisy image, predict the clean version in a single step.

Single-Step Denoising Objective:

Single-Step Denoising Loss Function

Noise Generation Process:

Noise Generation Mathematics
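Written out, a formulation consistent with the description above is the following. The symbols $D_\theta$ (the denoiser), $x$ (the clean image), $z$ (the noisy image), and $\epsilon$ are notation assumed here for illustration.

```latex
% Single-step denoising objective: L2 distance between the denoised and clean image
L = \mathbb{E}_{x,\,z}\,\bigl\| D_\theta(z) - x \bigr\|^2

% Noise generation: add Gaussian noise of standard deviation \sigma to the clean image
z = x + \sigma\,\epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)
```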

MNIST Training Setup

Training Configuration:

Dataset: MNIST (60,000 training images)
Batch Size: 256
Epochs: 5
Optimizer: Adam (lr=1e-4)
Hidden Dimension: D = 128
Noise Level: σ = 0.5
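One optimization step under this configuration might look like the sketch below. The small convolutional `model` and the random `batch` are hypothetical stand-ins for the UNet and a real MNIST batch, and `denoising_step` is a name introduced here for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

SIGMA = 0.5  # noise level from the configuration above


def denoising_step(model, optimizer, x, sigma=SIGMA):
    """Noise the batch, predict the clean image, and take one L2-loss step."""
    z = x + sigma * torch.randn_like(x)   # noisy input
    loss = F.mse_loss(model(z), x)        # compare prediction with the clean image
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


# Hypothetical stand-in for the UNet: any module mapping (B, 1, 28, 28) to the same shape.
model = nn.Sequential(
    nn.Conv2d(1, 8, 3, padding=1), nn.GELU(), nn.Conv2d(8, 1, 3, padding=1)
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
batch = torch.rand(256, 1, 28, 28)        # random stand-in for 256 MNIST images
loss = denoising_step(model, optimizer, batch)
```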

Noise Level Visualization

Figure: MNIST digits at noise levels σ = 0.0, 0.2, 0.4, 0.6, 0.8, and 1.0

Training Results

Figure: Training loss curve over 5 epochs of single-step denoiser training

Epoch 1

Single-Step Results Epoch 1
Initial denoising capability

Epoch 5

Single-Step Results Epoch 5
Converged denoising performance

Out-of-Distribution Testing

Testing the single-step denoiser on noise levels it wasn't trained on reveals the limitations of this approach and motivates iterative diffusion.

Figure: Out-of-distribution denoising results at σ = 0.1, 0.3, 0.7, and 0.9
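One way to quantify this kind of test is to measure reconstruction error at each noise level. The sketch below is only an illustration of the idea: `mse_by_sigma` is a hypothetical helper, the single convolution stands in for the trained denoiser, and the random tensor stands in for test images.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


@torch.no_grad()
def mse_by_sigma(model, x, sigmas):
    """Return the mean squared denoising error for each noise level."""
    model.eval()
    results = {}
    for sigma in sigmas:
        z = x + sigma * torch.randn_like(x)
        results[sigma] = F.mse_loss(model(z), x).item()
    return results


model = nn.Conv2d(1, 1, 3, padding=1)   # hypothetical stand-in for the trained denoiser
x = torch.rand(16, 1, 28, 28)           # random stand-in for MNIST test images
errors = mse_by_sigma(model, x, [0.1, 0.3, 0.5, 0.7, 0.9])
```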

Part 2: Full DDPM Training

2.1 Time-Conditioned UNet Architecture

Adding Temporal Conditioning

To enable iterative denoising, we condition the UNet on timestep t, allowing it to adapt its behavior based on the current noise level.

Figure: Time-conditioned UNet architecture
Figure: FCBlock architecture for the time embedding
    # Time conditioning implementation
    fc1_t = FCBlock(1, hidden_dim)  # Scalar timestep input
    fc2_t = FCBlock(1, hidden_dim)

    # Normalize timestep to [0, 1]
    t_normalized = t / 300.0

    # Embed and inject conditioning
    t1 = fc1_t(t_normalized)
    t2 = fc2_t(t_normalized)

    # Modulate UNet features
    unflatten = unflatten + t1
    up1 = up1 + t2
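For readers who want to run the conditioning step in isolation, here is a minimal sketch. The internals of `FCBlock` (Linear, GELU, Linear) are an assumption, as is the choice to make the embedding width equal to the number of feature channels so that the embedding broadcasts over the 7×7 feature map.

```python
import torch
import torch.nn as nn


class FCBlock(nn.Module):
    """Small MLP that maps a conditioning vector to a per-channel embedding (assumed layout)."""

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, out_dim), nn.GELU(), nn.Linear(out_dim, out_dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


T = 300
channels = 64                                        # assumed feature-map width
fc1_t = FCBlock(1, channels)

t = torch.randint(0, T, (8,))                        # one integer timestep per sample
t_normalized = (t.float() / T).unsqueeze(-1)         # shape (8, 1), values in [0, 1)
t1 = fc1_t(t_normalized)[:, :, None, None]           # shape (8, 64, 1, 1)

features = torch.randn(8, channels, 7, 7)            # stand-in for the unflatten output
conditioned = features + t1                          # broadcast over height and width
```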

2.2 DDPM Mathematics and Variance Schedule

DDPM requires careful design of the noise schedule that controls how noise is added during training and removed during sampling.

DDPM Training Objective:

DDPM Training Loss Function

Forward Process with Schedule:

DDPM Forward Process Mathematics
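In the standard DDPM formulation, which matches the training steps described in this section, the objective and the forward process can be written as follows. The notation $\epsilon_\theta$ for the noise-predicting UNet is assumed here.

```latex
% Training objective: L2 distance between predicted and true noise
L = \mathbb{E}_{x_0,\,t,\,\epsilon}\,\bigl\| \epsilon_\theta(x_t, t) - \epsilon \bigr\|^2

% Forward process: noisy image at timestep t in closed form
x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1 - \bar{\alpha}_t}\,\epsilon,
\qquad \epsilon \sim \mathcal{N}(0, I)
```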

DDPM Variance Schedule (T=300):

β schedule: Linear from 1e-4 to 0.02
α_t: 1 - β_t
ᾱ_t: Cumulative product of α values
Properties: ᾱ_0 ≈ 1 (clean); ᾱ_T is small, about 0.05 for this schedule (nearly pure noise)
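The schedule above takes only a few lines to compute. This sketch assumes 0-indexed timesteps, so `alpha_bars[0]` is the first step and `alpha_bars[-1]` the last.

```python
import torch

T = 300
betas = torch.linspace(1e-4, 0.02, T)        # linear beta schedule
alphas = 1.0 - betas                         # alpha_t = 1 - beta_t
alpha_bars = torch.cumprod(alphas, dim=0)    # cumulative product of alphas

first, last = alpha_bars[0].item(), alpha_bars[-1].item()
```

Here `first` is very close to 1 and `last` is small but not exactly zero, so the final noised image still carries a faint trace of the original.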

2.3 DDPM Training Algorithm

Training Process

DDPM Training Algorithm:

  1. Sample: Get clean image x₀ from MNIST dataset
  2. Random Timestep: Sample t uniformly from [0, T-1]
  3. Add Noise: Create x_t using forward process
  4. Predict: Use UNet to predict noise ε given x_t and t
  5. Loss: Compute L2 loss between predicted and true noise
  6. Backprop: Update UNet parameters
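The six steps above can be sketched as a single loss function. `TinyEps` is a hypothetical stand-in for the time-conditioned UNet, the random tensor stands in for an MNIST batch, and passing the normalized timestep `t / T` to the model is an assumption based on the conditioning code earlier.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

T = 300
betas = torch.linspace(1e-4, 0.02, T)
alpha_bars = torch.cumprod(1.0 - betas, dim=0)


def ddpm_loss(model, x0):
    """Sample t, noise x0 to x_t, and return the L2 loss on the predicted noise."""
    b = x0.shape[0]
    t = torch.randint(0, T, (b,), device=x0.device)          # uniform over [0, T-1]
    eps = torch.randn_like(x0)                               # true noise
    a_bar = alpha_bars.to(x0.device)[t].view(b, 1, 1, 1)
    x_t = a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * eps     # forward process
    return F.mse_loss(model(x_t, t.float() / T), eps)


class TinyEps(nn.Module):
    """Hypothetical stand-in for the time-conditioned UNet."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3, padding=1)

    def forward(self, x, t):
        return self.conv(x) + t.view(-1, 1, 1, 1)


model = TinyEps()
loss = ddpm_loss(model, torch.rand(128, 1, 28, 28))   # random stand-in for a batch
loss.backward()                                       # gradients for the optimizer step
```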

DDPM Training Configuration:

Timesteps: T = 300
Batch Size: 128
Epochs: 20
Optimizer: Adam (lr=1e-3)
LR Schedule: Exponential decay (γ=0.99)
Hidden Dimension: D = 64
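The optimizer and learning-rate schedule from this configuration can be set up as sketched below. Stepping the scheduler once per epoch is an assumption; the configuration does not say whether the decay is applied per epoch or per batch. The single convolution is a placeholder for the UNet.

```python
import torch
import torch.nn as nn

model = nn.Conv2d(1, 1, 3, padding=1)   # placeholder for the time-conditioned UNet
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)

lrs = []
for epoch in range(20):
    # One pass over the training set would go here.
    optimizer.step()                     # placeholder step so the call order is valid
    scheduler.step()                     # multiply the learning rate by gamma
    lrs.append(scheduler.get_last_lr()[0])
```

Under the per-epoch assumption the learning rate after 20 epochs is 1e-3 × 0.99^20, or roughly 82% of its starting value.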

2.4 DDPM Sampling Algorithm

The sampling process iteratively denoises pure noise into coherent MNIST digits using the learned reverse process.

DDPM Sampling Algorithm:

  1. Initialize: Start with pure noise x_T ~ N(0, I)
  2. Reverse Process: For t = T-1 down to 0:
     a. Predict Noise: ε_θ(x_t, t) using trained UNet
     b. Compute x_0: Estimate clean image from current state
     c. Step Back: Calculate x_{t-1} using DDPM reverse formula
     d. Add Variance: Include stochastic component for non-deterministic sampling
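The sampling loop above can be sketched as follows. This assumes 0-indexed timesteps, a per-step noise variance of β_t, and no noise on the final step; the `dummy` model that always predicts zero noise is a hypothetical stand-in used only to make the example runnable.

```python
import torch

T = 300
betas = torch.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)


@torch.no_grad()
def ddpm_sample(model, shape):
    """Iteratively denoise pure Gaussian noise with the DDPM reverse process."""
    x = torch.randn(shape)                                    # start from pure noise
    for t in range(T - 1, -1, -1):
        a_bar = alpha_bars[t]
        a_bar_prev = alpha_bars[t - 1] if t > 0 else torch.tensor(1.0)
        t_batch = torch.full((shape[0],), t / T)              # normalized timestep
        eps_hat = model(x, t_batch)                           # predicted noise
        # Estimate the clean image from the current state.
        x0_hat = (x - (1.0 - a_bar).sqrt() * eps_hat) / a_bar.sqrt()
        # Posterior mean: weighted mix of the clean estimate and the current state.
        mean = (a_bar_prev.sqrt() * betas[t] / (1.0 - a_bar)) * x0_hat \
            + (alphas[t].sqrt() * (1.0 - a_bar_prev) / (1.0 - a_bar)) * x
        # Stochastic term, skipped on the final step.
        noise = torch.randn_like(x) if t > 0 else torch.zeros_like(x)
        x = mean + betas[t].sqrt() * noise
    return x


dummy = lambda x, t: torch.zeros_like(x)   # hypothetical stand-in for the trained UNet
samples = ddpm_sample(dummy, (4, 1, 28, 28))
```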

DDPM Reverse Process:

DDPM Reverse Process Mathematics
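One common way to write the reverse step, consistent with the sampling steps listed above, is shown below. The choice of $\beta_t$ as the sampling variance and the convention $\bar{\alpha}_{t-1} = 1$ at the final step are assumptions.

```latex
% Clean-image estimate from the current state and the predicted noise
\hat{x}_0 = \frac{x_t - \sqrt{1 - \bar{\alpha}_t}\,\epsilon_\theta(x_t, t)}{\sqrt{\bar{\alpha}_t}}

% Reverse step: posterior mean plus a stochastic term (z = 0 on the final step)
x_{t-1} = \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1 - \bar{\alpha}_t}\,\hat{x}_0
        + \frac{\sqrt{\alpha_t}\,(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t}\,x_t
        + \sqrt{\beta_t}\,z,
\qquad z \sim \mathcal{N}(0, I)
```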

Time-Conditioned Results

Time-Conditioned Training Loss
20 epochs with exponential LR decay

Epoch 5

Time-Conditioned Epoch 5
Early training progress

Epoch 20

Time-Conditioned Epoch 20
Final generation quality

2.5 Class-Conditioned Diffusion

Adding Class Conditioning for Controlled Generation

By conditioning on digit classes (0-9), we gain precise control over what the model generates while enabling classifier-free guidance.

Figure: Class- and time-conditioned UNet architecture
    # Class conditioning implementation
    fc1_t = FCBlock(1, hidden_dim)   # Time embedding
    fc1_c = FCBlock(10, hidden_dim)  # Class embedding (one-hot)
    fc2_t = FCBlock(1, hidden_dim)
    fc2_c = FCBlock(10, hidden_dim)

    # Conditional dropout (10% unconditional training)
    if random.random() < 0.1:
        c = torch.zeros_like(c)  # Drop class conditioning

    # Embed both conditions
    t1 = fc1_t(t_normalized)
    c1 = fc1_c(c)
    t2 = fc2_t(t_normalized)
    c2 = fc2_c(c)

    # Multiplicative and additive conditioning
    unflatten = c1 * unflatten + t1
    up1 = c2 * up1 + t2

Class-Conditional Training:

  1. One-Hot Encoding: Convert digit labels to 10-dimensional vectors
  2. Conditional Dropout: 10% of time, set class vector to zero
  3. Dual Conditioning: Inject both time and class information
  4. Mixed Training: Learn both conditional and unconditional generation
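The encoding and dropout steps can be sketched as below. Note one deliberate difference from the snippet earlier in this section: this version draws the dropout mask per sample rather than per batch, which is a common variant and an assumption here. `make_class_condition` is a name introduced for illustration.

```python
import torch
import torch.nn.functional as F


def make_class_condition(labels, p_uncond=0.1, num_classes=10):
    """One-hot encode labels and zero out each row with probability p_uncond."""
    c = F.one_hot(labels, num_classes).float()                     # (B, 10)
    keep = (torch.rand(labels.shape[0], 1) >= p_uncond).float()    # (B, 1) mask
    return c * keep                                                # dropped rows become all zeros


labels = torch.randint(0, 10, (128,))   # random stand-in for MNIST labels
c = make_class_condition(labels)
```

A row of all zeros plays the role of the empty class, so the same network learns both conditional and unconditional noise prediction.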

Classifier-Free Guidance (CFG)

Enhanced Generation Quality

CFG combines conditional and unconditional predictions to improve generation quality and adherence to class conditioning.

CFG for Class-Conditioned Generation:

Class-Conditioned CFG Mathematics
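The usual classifier-free guidance combination can be written as follows, where $\gamma$ is the guidance scale and $\varnothing$ denotes the dropped (all-zero) class vector. The symbol $\hat{\epsilon}$ for the combined estimate is notation assumed here.

```latex
% Guided noise estimate: start at the unconditional prediction and move
% gamma times the distance toward (and past) the conditional prediction
\hat{\epsilon} = \epsilon_\theta(x_t, t, \varnothing)
  + \gamma\,\bigl(\epsilon_\theta(x_t, t, c) - \epsilon_\theta(x_t, t, \varnothing)\bigr)
```

With $\gamma = 1$ this reduces to the conditional prediction; values above 1 extrapolate beyond it.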

CFG Sampling Process:

  1. Conditional Prediction: ε_θ(x_t, t, c) with class c
  2. Unconditional Prediction: ε_θ(x_t, t, ∅) with empty class
  3. Guidance Scale: γ = 5.0 for enhanced quality
  4. Combined Estimate: Extrapolate beyond conditional prediction
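The four steps above fit in a short function. This is a sketch under the assumption that the model takes `(x_t, t, c)` and that an all-zero class vector represents the empty class; `toy_model` is a hypothetical stand-in whose output depends only on whether a class is present.

```python
import torch


def cfg_noise(model, x_t, t, c, guidance_scale=5.0):
    """Combine conditional and unconditional noise predictions with CFG."""
    eps_cond = model(x_t, t, c)                       # prediction with class c
    eps_uncond = model(x_t, t, torch.zeros_like(c))   # prediction with the empty class
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)


# Hypothetical toy model: predicts 1 everywhere when a class is given, 0 otherwise.
toy_model = lambda x, t, c: c.sum(dim=1).view(-1, 1, 1, 1) * torch.ones_like(x)

x_t = torch.randn(2, 1, 28, 28)
t = torch.full((2,), 0.5)
c = torch.eye(10)[:2]                                 # one-hot vectors for classes 0 and 1
guided = cfg_noise(toy_model, x_t, t, c)
```

In this toy case the conditional prediction is 1 and the unconditional prediction is 0, so the guided estimate lands at the guidance scale itself, well past the conditional value.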

Class-Conditioned Results

Class-Conditioned Training Loss
20 epochs with class and time conditioning

Generated Digits by Class

Epoch 5

Figure: Generated samples for digit classes 0 through 9 at epoch 5

Epoch 20 (Final)

Figure: Generated samples for digit classes 0 through 9 at epoch 20

My Results

Implementation Achievements

I implemented and trained three diffusion model variants from scratch: a single-step denoiser, a time-conditioned DDPM, and a class-conditioned DDPM with classifier-free guidance.

Model Comparison

Single-Step Denoiser

Single-Step Results

Fast but limited to specific noise levels

Time-Conditioned DDPM

Time-Conditioned Results

Better quality through iterative refinement

Class-Conditioned DDPM + CFG

Class-Conditioned Results

Highest quality with precise control

Training Insights and Challenges

Extra Credit Explorations

Advanced Techniques

Improved UNet Architecture
Enhanced UNet
Additional skip connections and attention
Rectified Flow Results
Rectified Flow
State-of-the-art sampling efficiency

UNet Architecture Improvements:

Rectified Flow Implementation:

I replaced DDPM's variance schedule with straight-line interpolation between noise and data, which gave faster convergence and improved sample quality with fewer sampling steps.
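The core idea of rectified flow can be sketched as below. This is a generic illustration rather than the implementation used here: the convention that t = 0 is noise and t = 1 is data, the velocity-prediction target, the Euler sampler, and the function names are all assumptions.

```python
import torch
import torch.nn.functional as F


def rectified_flow_loss(model, data):
    """Regress the constant velocity of the straight line from noise to data."""
    noise = torch.randn_like(data)
    t = torch.rand(data.shape[0], device=data.device)   # uniform time in [0, 1)
    t_ = t.view(-1, 1, 1, 1)
    x_t = (1.0 - t_) * noise + t_ * data                # straight-line interpolation
    return F.mse_loss(model(x_t, t), data - noise)      # target velocity


@torch.no_grad()
def rectified_flow_sample(model, shape, num_steps=10):
    """Integrate the learned velocity field from noise toward data with Euler steps."""
    x = torch.randn(shape)
    dt = 1.0 / num_steps
    for i in range(num_steps):
        t = torch.full((shape[0],), i * dt)
        x = x + model(x, t) * dt
    return x


# Hypothetical stand-in model with a constant velocity of 1 everywhere.
const_velocity = lambda x, t: torch.ones_like(x)
loss = rectified_flow_loss(const_velocity, torch.rand(8, 1, 28, 28))
samples = rectified_flow_sample(const_velocity, (2, 1, 28, 28))
```

Because the target paths are straight lines, a well-trained velocity field can be integrated in far fewer steps than the 300 used by the DDPM sampler.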

Key Learnings

Deep Learning and Architecture Design

Diffusion Model Theory

Practical Implementation Skills

Impact and Applications

Foundational Understanding

Research Foundation: Deep understanding of generative modeling principles for future research

Architecture Design: Skills in designing and implementing neural network architectures

Training Methodologies: Experience with complex training procedures and optimization strategies

Conditional Generation: Understanding how to control generative models for specific applications

Scalability Principles: Knowledge of how these techniques scale to larger, more complex datasets

State-of-the-Art Connection: Direct experience with the foundations underlying modern AI systems