Train your own diffusion model on the MNIST dataset from the ground up.
This project provides deep insights into diffusion model training by implementing UNet architectures, understanding DDPM mathematics, and progressing from simple one-step denoisers to full conditional diffusion models with classifier-free guidance.
Training progression: From noisy samples to clean MNIST digit generation over 20 epochs
Project Details
| Assigned: | Thursday, April 17, 2025 |
| Due: | Tuesday, May 6, 2025 (Part B) |
| Individual Work: | Must be completed individually |
| Platform: | Google Colab with GPU acceleration |
| Key Concepts: | UNet architecture, DDPM training, time conditioning, CFG |
Overview
Training a diffusion model from scratch shows how these generative models learn to turn noise into data. This project covers the full path, from implementing a UNet to training a class-conditional diffusion model with classifier-free guidance.
Working with the MNIST dataset allows us to focus on the core algorithms while maintaining reasonable training times. We progress from simple one-step denoisers to a full DDPM implementation with time and class conditioning.
Part 1: Single-Step Denoising UNet
1.1 UNet Architecture Implementation
The UNet serves as the backbone of our diffusion model, featuring downsampling and upsampling paths with skip connections for multi-scale feature processing.
Unconditional UNet Architecture
Core UNet Components:
- Conv: Convolutional layer preserving spatial dimensions
- DownConv: Downsampling convolution (stride=2)
- UpConv: Upsampling transposed convolution (stride=2)
- Flatten: Average pooling to reduce 7×7 → 1×1
- Unflatten: Expand 1×1 → 7×7 through a transposed convolution
- Skip Connections: Concatenate features across scales
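The spatial sizes these blocks produce on a 28×28 MNIST input can be checked with the standard convolution size formulas. This is a minimal sketch: the kernel sizes and paddings below are assumptions chosen for illustration (the list above fixes only stride=2 and the 7×7 ↔ 1×1 step), and the helper names are hypothetical.

```python
def conv_out(n, kernel, stride=1, padding=0):
    """Output size of a convolution or pooling layer along one spatial axis."""
    return (n + 2 * padding - kernel) // stride + 1


def conv_transpose_out(n, kernel, stride=1, padding=0):
    """Output size of a transposed convolution along one spatial axis."""
    return (n - 1) * stride - 2 * padding + kernel


def trace_unet_sizes(n=28):
    """Trace the spatial size through a two-level UNet (assumed hyperparameters)."""
    sizes = {"input": n}
    # Conv: 3x3, stride 1, padding 1 keeps the size unchanged
    sizes["conv"] = conv_out(n, kernel=3, stride=1, padding=1)
    # DownConv: 3x3, stride 2, padding 1 halves the size
    sizes["down1"] = conv_out(sizes["conv"], kernel=3, stride=2, padding=1)
    sizes["down2"] = conv_out(sizes["down1"], kernel=3, stride=2, padding=1)
    # Flatten: 7x7 average pool collapses the map to a single value per channel
    sizes["flatten"] = conv_out(sizes["down2"], kernel=7, stride=7)
    # Unflatten: 7x7 transposed convolution expands 1x1 back to 7x7
    sizes["unflatten"] = conv_transpose_out(sizes["flatten"], kernel=7, stride=7)
    # UpConv: 4x4 transposed convolution, stride 2, padding 1 doubles the size
    sizes["up1"] = conv_transpose_out(sizes["unflatten"], kernel=4, stride=2, padding=1)
    sizes["up2"] = conv_transpose_out(sizes["up1"], kernel=4, stride=2, padding=1)
    return sizes


sizes = trace_unet_sizes()
```

Under these assumptions the encoder sizes (28, 14, 7) mirror the decoder sizes (7, 14, 28), which is what lets skip connections concatenate feature maps of matching resolution.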
UNet Operations Detail
1.2 Single-Step Denoiser Training
We begin with the fundamental denoising task: given a noisy image, predict the clean version in a single step.
Single-Step Denoising Objective:
Noise Generation Process:
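A common way to write this setup, which the description above appears to follow, is an L2 loss between the denoiser output and the clean image, with the noisy input formed by adding scaled Gaussian noise. The symbols x (clean image), z (noisy image), ε (noise), and D_θ (the UNet denoiser) are notation assumed here.

```latex
L = \mathbb{E}_{x, z} \left[ \lVert D_\theta(z) - x \rVert^2 \right]

z = x + \sigma \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)
```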
MNIST Training Setup
Training Configuration:
Dataset: MNIST (60,000 training images)
Batch Size: 256
Epochs: 5
Optimizer: Adam (lr=1e-4)
Hidden Dimension: D = 128
Noise Level: σ = 0.5
Noise Level Visualization: σ = 0.0, 0.2, 0.4, 0.6, 0.8, 1.0
Training Results
Training Loss Curve: 5 epochs of single-step denoiser training
Epoch 1: Initial denoising capability
Epoch 5: Converged denoising performance
Out-of-Distribution Testing
Testing the single-step denoiser on noise levels it wasn't trained on reveals the limitations of this approach and motivates iterative diffusion.
Test noise levels shown: σ = 0.1, 0.3, 0.7, 0.9
Part 2: Full DDPM Training
2.1 Time-Conditioned UNet Architecture
Adding Temporal Conditioning
To enable iterative denoising, we condition the UNet on timestep t, allowing it to adapt its behavior based on the current noise level.
Time-Conditioned UNet
FCBlock for Time Embedding
# Time conditioning implementation
fc1_t = FCBlock(1, hidden_dim) # Scalar timestep input
fc2_t = FCBlock(1, hidden_dim)
# Normalize timestep to [0, 1]
t_normalized = t / 300.0
# Embed and inject conditioning
t1 = fc1_t(t_normalized)
t2 = fc2_t(t_normalized)
# Modulate UNet features
unflatten = unflatten + t1
up1 = up1 + t2
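The snippet above leaves `FCBlock` undefined and glosses over tensor shapes. The following is a minimal sketch of how it might look, assuming `FCBlock` is a small two-layer MLP (the Linear → GELU → Linear layout is an assumption, not the project's confirmed design) and that the embedding must be reshaped to broadcast over the spatial dimensions. The helper `add_time_embedding` is hypothetical.

```python
import torch
import torch.nn as nn


class FCBlock(nn.Module):
    """Hypothetical FCBlock: Linear -> GELU -> Linear (layer choice is an assumption)."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features, out_features),
            nn.GELU(),
            nn.Linear(out_features, out_features),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


def add_time_embedding(features: torch.Tensor, t: torch.Tensor,
                       block: FCBlock, num_timesteps: int = 300) -> torch.Tensor:
    """Add a per-channel time embedding to a (N, C, H, W) feature map."""
    # Scale integer timesteps to roughly [0, 1] and give them shape (N, 1)
    t_normalized = t.float().view(-1, 1) / num_timesteps
    emb = block(t_normalized)                      # (N, C)
    # Reshape to (N, C, 1, 1) so the embedding broadcasts over H and W
    return features + emb[:, :, None, None]
```

Each sample in the batch gets its own offset per channel, and that offset is the same at every spatial location.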
2.2 DDPM Mathematics and Variance Schedule
DDPM requires careful design of the noise schedule that controls how noise is added during training and removed during sampling.
DDPM Training Objective:
Forward Process with Schedule:
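In the standard DDPM formulation, the network ε_θ is trained to predict the noise that was mixed into the image, and the forward process has a closed form in terms of ᾱ_t. The notation below is the usual one and is assumed to match the project's.

```latex
L = \mathbb{E}_{x_0, t, \epsilon} \left[ \lVert \epsilon_\theta(x_t, t) - \epsilon \rVert^2 \right]

x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)
```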
DDPM Variance Schedule (T=300):
β schedule: Linear from 1e-4 to 0.02
α_t: 1 - β_t
ᾱ_t: Cumulative product of α values
Properties: ᾱ_0 ≈ 1 (clean); ᾱ_T is small (about 0.05 for this schedule), so x_T is close to pure noise
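The schedule above is small enough to check numerically with plain Python. This is a sketch under the stated settings (T = 300, β linear from 1e-4 to 0.02); the function name is hypothetical.

```python
def linear_beta_schedule(num_timesteps=300, beta_start=1e-4, beta_end=0.02):
    """Return lists of beta_t, alpha_t and cumulative alpha_bar_t for t = 0..T-1."""
    step = (beta_end - beta_start) / (num_timesteps - 1)
    betas = [beta_start + i * step for i in range(num_timesteps)]
    alphas = [1.0 - b for b in betas]
    alpha_bars = []
    running = 1.0
    for a in alphas:
        running *= a            # cumulative product of alpha values
        alpha_bars.append(running)
    return betas, alphas, alpha_bars


betas, alphas, alpha_bars = linear_beta_schedule()
```

With these values, `alpha_bars[0]` is 1 - 1e-4 and `alpha_bars[-1]` lands near 0.05 rather than exactly 0, so the final state keeps a small trace of the original image.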
2.3 DDPM Training Algorithm
Training Process
DDPM Training Algorithm:
- Sample: Get clean image x₀ from MNIST dataset
- Random Timestep: Sample t uniformly from [0, T-1]
- Add Noise: Create x_t using forward process
- Predict: Use UNet to predict noise ε given x_t and t
- Loss: Compute L2 loss between predicted and true noise
- Backprop: Update UNet parameters
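The steps above can be sketched as a single PyTorch training step. This is an illustration, not the project's actual code: it assumes `model(x_t, t)` returns a tensor shaped like `x_t`, and the function names are hypothetical.

```python
import torch
import torch.nn.functional as F


def make_alpha_bars(num_timesteps: int = 300) -> torch.Tensor:
    """Cumulative products of alpha_t for a linear beta schedule."""
    betas = torch.linspace(1e-4, 0.02, num_timesteps)
    return torch.cumprod(1.0 - betas, dim=0)


def ddpm_training_step(model, x0, alpha_bars, optimizer):
    """One optimization step of noise-prediction training.

    `model(x_t, t)` is assumed to return a tensor shaped like `x_t`.
    """
    n = x0.shape[0]
    num_timesteps = alpha_bars.shape[0]
    # Random timestep per sample, uniform over [0, T-1]
    t = torch.randint(0, num_timesteps, (n,), device=x0.device)
    eps = torch.randn_like(x0)
    # Forward process: x_t = sqrt(a_bar) * x0 + sqrt(1 - a_bar) * eps
    a_bar = alpha_bars.to(x0.device)[t].view(n, 1, 1, 1)
    x_t = a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * eps
    # L2 loss between predicted and true noise
    loss = F.mse_loss(model(x_t, t), eps)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```

Sampling a different `t` for every image in the batch lets one step cover many noise levels at once.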
DDPM Training Configuration:
Timesteps: T = 300
Batch Size: 128
Epochs: 20
Optimizer: Adam (lr=1e-3)
LR Schedule: Exponential decay (γ=0.99)
Hidden Dimension: D = 64
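To see how mild this decay is, the learning rate can be computed directly. The sketch assumes the decay is applied once per epoch (as with stepping `torch.optim.lr_scheduler.ExponentialLR` after each epoch), which the configuration above does not state explicitly.

```python
def exponential_lr(initial_lr: float, gamma: float, num_decays: int) -> float:
    """Learning rate after `num_decays` multiplicative decay steps."""
    return initial_lr * gamma ** num_decays


# Learning rate at the start of each epoch, plus the value after the final decay
lr_by_epoch = [exponential_lr(1e-3, 0.99, epoch) for epoch in range(21)]
```

Under that assumption the rate falls from 1e-3 to about 8.18e-4 after 20 decay steps, a reduction of roughly 18%.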
2.4 DDPM Sampling Algorithm
The sampling process iteratively denoises pure noise into coherent MNIST digits using the learned reverse process.
DDPM Sampling Algorithm:
- Initialize: Start with pure noise x_T ~ N(0, I)
- Reverse Process: For t = T-1 down to 0:
- Predict Noise: ε_θ(x_t, t) using trained UNet
- Compute x_0: Estimate clean image from current state
- Step Back: Calculate x_{t-1} using DDPM reverse formula
- Add Variance: Include stochastic component for non-deterministic sampling
DDPM Reverse Process:
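One way to implement the sampling loop described above is to estimate x_0 at every step and then take the posterior mean of q(x_{t-1} | x_t, x_0). The sketch below assumes a linear β schedule, a `model(x, t)` that returns predicted noise shaped like `x`, and a sampling variance of β_t (one common choice). None of these details are confirmed by the project.

```python
import torch


@torch.no_grad()
def ddpm_sample(model, shape, num_timesteps=300, device="cpu"):
    """Generate samples by running the DDPM reverse process from pure noise."""
    betas = torch.linspace(1e-4, 0.02, num_timesteps, device=device)
    alphas = 1.0 - betas
    alpha_bars = torch.cumprod(alphas, dim=0)

    x = torch.randn(shape, device=device)          # x_T ~ N(0, I)
    for t in range(num_timesteps - 1, -1, -1):
        t_batch = torch.full((shape[0],), t, device=device, dtype=torch.long)
        eps = model(x, t_batch)                    # predicted noise
        a_bar = alpha_bars[t]
        # alpha_bar at the previous step; defined as 1 before the first step
        a_bar_prev = alpha_bars[t - 1] if t > 0 else torch.ones((), device=device)
        # Estimate of the clean image from the current state
        x0_hat = (x - (1.0 - a_bar).sqrt() * eps) / a_bar.sqrt()
        # Posterior mean: weighted mix of the clean estimate and the current state
        coef_x0 = a_bar_prev.sqrt() * betas[t] / (1.0 - a_bar)
        coef_xt = alphas[t].sqrt() * (1.0 - a_bar_prev) / (1.0 - a_bar)
        mean = coef_x0 * x0_hat + coef_xt * x
        # Stochastic term, omitted on the final step
        noise = torch.randn_like(x) if t > 0 else torch.zeros_like(x)
        x = mean + betas[t].sqrt() * noise
    return x
```

At the final step (t = 0) the weight on the current state vanishes and the weight on the clean estimate becomes 1, so the loop returns the model's last x_0 estimate.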
Time-Conditioned Results
Time-Conditioned Training Loss: 20 epochs with exponential LR decay
Epoch 5: Early training progress
Epoch 20: Final generation quality
2.5 Class-Conditioned Diffusion
Adding Class Conditioning for Controlled Generation
By conditioning on digit classes (0-9), we gain precise control over what the model generates while enabling classifier-free guidance.
Class and Time Conditioned UNet
# Class conditioning implementation
fc1_t = FCBlock(1, hidden_dim) # Time embedding
fc1_c = FCBlock(10, hidden_dim) # Class embedding (one-hot)
fc2_t = FCBlock(1, hidden_dim)
fc2_c = FCBlock(10, hidden_dim)
# Conditional dropout (10% unconditional training)
if random.random() < 0.1:
    c = torch.zeros_like(c)  # Drop class conditioning for this batch
# Embed both conditions
t1 = fc1_t(t_normalized)
c1 = fc1_c(c)
t2 = fc2_t(t_normalized)
c2 = fc2_c(c)
# Multiplicative and additive conditioning
unflatten = c1 * unflatten + t1
up1 = c2 * up1 + t2
Class-Conditional Training:
- One-Hot Encoding: Convert digit labels to 10-dimensional vectors
- Conditional Dropout: 10% of time, set class vector to zero
- Dual Conditioning: Inject both time and class information
- Mixed Training: Learn both conditional and unconditional generation
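The snippet above drops the class vector for a whole batch at once. A common alternative, shown here as a sketch rather than the project's method, draws an independent mask per sample so that every batch mixes conditional and unconditional examples. The function name and the `generator` argument are hypothetical additions.

```python
import torch
import torch.nn.functional as F


def one_hot_with_dropout(labels, num_classes=10, p_uncond=0.1, generator=None):
    """One-hot encode digit labels and zero out each row with probability p_uncond."""
    c = F.one_hot(labels, num_classes).float()               # (N, num_classes)
    # Independent keep/drop decision per sample
    u = torch.rand(labels.shape[0], generator=generator).to(labels.device)
    keep = (u >= p_uncond).float()
    return c * keep[:, None]                                 # dropped rows become all zeros
```

An all-zero row plays the role of the empty class ∅ that classifier-free guidance uses for its unconditional prediction.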
Classifier-Free Guidance (CFG)
Enhanced Generation Quality
CFG combines conditional and unconditional predictions to improve generation quality and adherence to class conditioning.
CFG for Class-Conditioned Generation:
CFG Sampling Process:
- Conditional Prediction: ε_θ(x_t, t, c) with class c
- Unconditional Prediction: ε_θ(x_t, t, ∅) with empty class
- Guidance Scale: γ = 5.0 for enhanced quality
- Combined Estimate: Extrapolate beyond conditional prediction
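The combined estimate is usually written as ε = ε_uncond + γ (ε_cond − ε_uncond). A minimal sketch follows, assuming that standard form. The function name is hypothetical, and because it uses only arithmetic operators it works the same on floats and on PyTorch tensors.

```python
def cfg_combine(eps_uncond, eps_cond, guidance_scale=5.0):
    """Classifier-free guidance: push the estimate along (eps_cond - eps_uncond)."""
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)
```

A scale of 0 returns the unconditional prediction and a scale of 1 returns the conditional one. Values above 1, such as the γ = 5.0 used here, extrapolate past the conditional prediction, which tends to trade sample diversity for closer adherence to the class.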
Class-Conditioned Results
Class-Conditioned Training Loss
20 epochs with class and time conditioning
Generated Digits by Class
My Results
Implementation Achievements
Implemented and trained three models from scratch: a single-step denoiser, a time-conditioned DDPM, and a class-conditioned DDPM with classifier-free guidance. Together they cover UNet architecture, DDPM mathematics, and conditional generation.
Model Comparison
Single-Step Denoiser: Fast but limited to specific noise levels
Time-Conditioned DDPM: Better quality through iterative refinement
Class-Conditioned DDPM + CFG: Highest quality with precise control
Training Insights and Challenges
- Architecture Design: Balancing model capacity with training efficiency through proper hidden dimensions
- Conditioning Mechanisms: Understanding additive vs. multiplicative conditioning for different information types
- Variance Scheduling: The critical role of β schedules in training stability and generation quality
- CFG Trade-offs: Balancing quality improvements against diversity reduction
- Training Dynamics: Managing long training times and GPU memory constraints
Extra Credit Explorations
Key Learnings
Deep Learning and Architecture Design
- UNet Mastery: Understanding encoder-decoder architectures with skip connections for multi-scale processing
- Conditioning Strategies: Different approaches for injecting control signals into neural networks
- Training Stability: Learning rate scheduling, gradient clipping, and architecture choices for stable training
- Memory Optimization: Efficient tensor operations and checkpointing for GPU memory management
- Hyperparameter Sensitivity: Understanding the impact of batch size, learning rate, and architecture choices
Diffusion Model Theory
- DDPM Mathematics: Deep understanding of forward and reverse processes, variance schedules
- Noise Scheduling: How β schedules affect training dynamics and generation quality
- Classifier-Free Guidance: Mathematical foundation and practical implementation of CFG
- Conditional Generation: Techniques for controlling generative models with external signals
- Sampling Algorithms: Trade-offs between sampling steps, quality, and diversity
Practical Implementation Skills
- PyTorch Proficiency: Advanced tensor operations, custom loss functions, and training loops
- Debugging Strategies: Identifying and fixing issues in complex generative models
- Experiment Design: Systematic evaluation of different architectures and hyperparameters
- Model Checkpointing: Robust training workflows with resume capabilities
- Visualization Tools: Creating informative plots and comparisons for model evaluation
Impact and Applications
Foundational Understanding
Research Foundation: Deep understanding of generative modeling principles for future research
Architecture Design: Skills in designing and implementing neural network architectures
Training Methodologies: Experience with complex training procedures and optimization strategies
Conditional Generation: Understanding how to control generative models for specific applications
Scalability Principles: Knowledge of how these techniques scale to larger, more complex datasets
State-of-the-Art Connection: Direct experience with the foundations underlying modern AI systems