Lesson 7: Generative Models - VAEs and GANs
Introduction
Welcome to Lesson 7! Today, we're diving into the fascinating world of generative models, specifically Variational Autoencoders (VAEs) and Generative Adversarial Networks (GANs). These models learn to create new data, such as images or text, that resembles the data they were trained on. It's like teaching a computer to be creative!
By the end of this lesson, you'll understand how VAEs and GANs work, and you'll create your own simple image generation model. Don't worry if these concepts seem complex at first - we'll break them down step by step and use simple analogies to explain them.
1. Variational Autoencoders (VAEs)
Imagine you're an artist trying to paint landscapes. Instead of memorizing every tree and cloud, you learn to capture the essence of landscapes - things like colors, shapes, and compositions. This 'essence' is what we call the latent space in VAEs.
A VAE has two main parts:
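- An encoder that turns input data (like images) into this 'essence' (latent space). More precisely, it outputs a mean and a variance describing a small region of the latent space, and a latent vector is sampled from that region
- A decoder that turns a latent vector back into data. Decoding latent vectors is how a VAE generates new images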
Here's a simple implementation of a VAE using PyTorch:
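```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

class VAE(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(VAE, self).__init__()
        self.input_dim = input_dim
        # Encoder: input -> hidden -> (mean, log-variance) of the latent distribution
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        # Decoder: latent vector -> hidden -> reconstructed input
        self.fc2 = nn.Linear(latent_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        h = F.relu(self.fc1(x))
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        # Reparameterization trick: z = mu + sigma * eps keeps sampling
        # differentiable with respect to mu and logvar
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h = F.relu(self.fc2(z))
        return torch.sigmoid(self.fc3(h))  # pixel values in [0, 1]

    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, self.input_dim))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

def loss_function(recon_x, x, mu, logvar):
    # Reconstruction term: binary cross-entropy summed over all pixels
    bce = F.binary_cross_entropy(recon_x, x.view(-1, recon_x.size(1)), reduction="sum")
    # KL divergence between N(mu, sigma^2) and the standard normal prior N(0, I)
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return bce + kld

# Assumed setup: torchvision MNIST, batch size 128.
# ToTensor() scales pixels to [0, 1], matching the decoder's sigmoid output.
train_loader = DataLoader(
    datasets.MNIST("data", train=True, download=True, transform=transforms.ToTensor()),
    batch_size=128, shuffle=True)

# Instantiate and use the VAE
vae = VAE(784, 400, 20)  # For MNIST: 28x28=784 input dim, 400 hidden dim, 20 latent dim
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)

# Training loop (simplified)
for epoch in range(100):
    for batch_idx, (data, _) in enumerate(train_loader):
        optimizer.zero_grad()
        recon_batch, mu, logvar = vae(data)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        optimizer.step()
```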
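This code defines a VAE that can be trained on image data (like the MNIST dataset of handwritten digits). The encoder maps each flattened 28x28 image to a mean and log-variance in a 20-dimensional latent space, reparameterize samples a latent vector from that distribution, and the decoder reconstructs the image from it. Training minimizes reconstruction error plus a KL-divergence term that keeps the encoder's distributions close to a standard normal. That KL term is what later lets us generate new images by decoding random latent vectors.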
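The KL term in a VAE loss has a closed form when the encoder outputs a diagonal Gaussian N(mu, sigma^2) and the prior is N(0, I): KL = -0.5 * sum(1 + log sigma^2 - mu^2 - sigma^2). Since the encoder predicts logvar = log sigma^2, this becomes -0.5 * sum(1 + logvar - mu^2 - exp(logvar)). The sketch below checks that formula against PyTorch's built-in torch.distributions KL. The mu and logvar values are made up for illustration.

```python
import torch
from torch.distributions import Normal
from torch.distributions.kl import kl_divergence

# Illustrative encoder outputs for one image with a 3-dimensional latent space
mu = torch.tensor([0.5, -1.0, 0.0])
logvar = torch.tensor([0.0, 0.3, -0.7])

# Closed-form KL(N(mu, sigma^2) || N(0, I)), summed over latent dimensions
kl_closed_form = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

# Same quantity via torch.distributions (Normal takes a standard deviation, not a variance)
std = torch.exp(0.5 * logvar)
kl_reference = kl_divergence(Normal(mu, std), Normal(torch.zeros(3), torch.ones(3))).sum()

print(kl_closed_form.item(), kl_reference.item())  # the two values agree
```

The closed form is what makes the VAE loss cheap to compute. No sampling is needed for the KL term, only the encoder's mu and logvar outputs.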
Exploring the VAE Latent Space
Once a VAE is trained on handwritten digits, every point in its latent space decodes to an image, and nearby points decode to similar images. You can pick any point and decode it into a new digit. Because the KL term keeps the encodings close to a standard normal, the space tends to have few empty gaps. As a result, walking along a straight line between the latent vectors of two digits produces a smooth interpolation from one digit to the other.
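Latent-space interpolation can be sketched as follows. The decoder here is an untrained stand-in with the same layer sizes as VAE.decode above, so the frames are just noise. That stand-in is an assumption made to keep the example self-contained, and the helper interpolate_latents is hypothetical.

```python
import torch
import torch.nn as nn

def interpolate_latents(z_start, z_end, steps):
    # Hypothetical helper: evenly spaced points on the line from z_start to z_end
    alphas = torch.linspace(0.0, 1.0, steps).unsqueeze(1)  # shape (steps, 1)
    return (1 - alphas) * z_start + alphas * z_end          # shape (steps, latent_dim)

latent_dim = 20
# Untrained stand-in for a trained VAE decoder (latent -> 400 -> 784 pixels in [0, 1])
decoder = nn.Sequential(nn.Linear(latent_dim, 400), nn.ReLU(), nn.Linear(400, 784), nn.Sigmoid())

z_a = torch.randn(latent_dim)
z_b = torch.randn(latent_dim)
path = interpolate_latents(z_a, z_b, steps=8)
with torch.no_grad():
    frames = decoder(path).view(-1, 1, 28, 28)  # 8 images morphing from z_a to z_b
print(frames.shape)  # torch.Size([8, 1, 28, 28])
```

With a trained model, you would use vae.decode in place of the stand-in. Take z_a and z_b as the mu values returned by vae.encode for two real digits, so the path starts and ends near those digits.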
2. Generative Adversarial Networks (GANs)
Now, let's talk about GANs. Imagine a game between two players: an art forger (the Generator) and an art detective (the Discriminator). The forger tries to create fake paintings that look real, while the detective tries to spot the fakes. As they play, both get better at their jobs.
In a GAN:
- The Generator creates fake data (like images) from random noise
- The Discriminator tries to distinguish between real and fake data
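- They are trained in alternation: the Discriminator learns from real and generated examples, and the Generator learns from the Discriminator's feedback, until the Generator's output is hard to tell apart from real data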
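Formally, the standard GAN objective writes this game as a minimax problem. Here D(x) is the Discriminator's estimated probability that x is real, and G(z) is the Generator's output for noise z:

```latex
\min_G \max_D \; \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]
```

The Discriminator's half of this objective is binary cross-entropy with label 1 for real images and 0 for generated ones. For the Generator, implementations commonly minimize -log D(G(z)) instead of log(1 - D(G(z))), because it gives stronger gradients early in training. In code, this amounts to applying BCE with a "real" target to generated images.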
Here's a simple implementation of a GAN using PyTorch:
```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

class Generator(nn.Module):
    def __init__(self, latent_dim, img_shape):
        super(Generator, self).__init__()
        self.img_shape = img_shape

        def block(in_feat, out_feat, normalize=True):
            # Linear layer, optional batch norm, then LeakyReLU
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # Default eps; a second positional argument would set eps, not momentum
                layers.append(nn.BatchNorm1d(out_feat))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(torch.prod(torch.tensor(img_shape)))),
            nn.Tanh()  # pixel values in [-1, 1]
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *self.img_shape)
        return img

class Discriminator(nn.Module):
    def __init__(self, img_shape):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(int(torch.prod(torch.tensor(img_shape))), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),  # probability that the input image is real
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)
        return validity

# Instantiate and use the GAN
latent_dim = 100
img_shape = (1, 28, 28)  # For MNIST
generator = Generator(latent_dim, img_shape)
discriminator = Discriminator(img_shape)

# Loss function and optimizers
adversarial_loss = nn.BCELoss()
optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Assumed setup: torchvision MNIST, batch size 64.
# Pixels are scaled to [-1, 1] to match the Generator's Tanh output.
dataloader = DataLoader(
    datasets.MNIST("data", train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize([0.5], [0.5]),
                   ])),
    batch_size=64, shuffle=True)

# Training loop (simplified)
for epoch in range(200):
    for i, (imgs, _) in enumerate(dataloader):
        # Train Discriminator: real images -> 1, generated images -> 0
        optimizer_D.zero_grad()
        real_loss = adversarial_loss(discriminator(imgs), torch.ones(imgs.size(0), 1))
        z = torch.randn(imgs.size(0), latent_dim)
        fake_imgs = generator(z)
        # detach() keeps this step from sending gradients into the Generator
        fake_loss = adversarial_loss(discriminator(fake_imgs.detach()), torch.zeros(imgs.size(0), 1))
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        # Train Generator: reward it when the Discriminator labels its images as real
        optimizer_G.zero_grad()
        g_loss = adversarial_loss(discriminator(fake_imgs), torch.ones(imgs.size(0), 1))
        g_loss.backward()
        optimizer_G.step()
```
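This code defines a GAN for 28x28 MNIST images. The Generator maps a 100-dimensional noise vector to an image, and the Discriminator outputs the probability that an image is real. Each training step has two parts. It first updates the Discriminator on a real batch (label 1) and a generated batch (label 0), using detach() so that this update leaves the Generator alone. It then updates the Generator, rewarding it when the Discriminator labels its images as real.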
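The detach() call is easy to overlook. The sketch below shows where gradients flow in each half of a training step. It uses tiny stand-in networks (a single nn.Linear as the Generator and a one-layer Discriminator), which are assumptions for illustration, not the classes above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
G = nn.Linear(4, 3)                               # stand-in Generator: 4-dim noise -> 3 "pixels"
D = nn.Sequential(nn.Linear(3, 1), nn.Sigmoid())  # stand-in Discriminator
bce = nn.BCELoss()

fake = G(torch.randn(5, 4))

# Discriminator step: detach() cuts the graph at the fake images
d_loss = bce(D(fake.detach()), torch.zeros(5, 1))
d_loss.backward()
g_untouched = G.weight.grad is None        # no gradient reached G
d_updated = D[0].weight.grad is not None   # D received gradients

# Generator step: without detach(), gradients flow back through D into G.
# D's gradients also accumulate here; the real loop clears them with optimizer_D.zero_grad().
g_loss = bce(D(fake), torch.ones(5, 1))
g_loss.backward()
g_updated = G.weight.grad is not None

print(g_untouched, d_updated, g_updated)   # True True True
```

Without detach() in the Discriminator step, d_loss.backward() would also compute gradients for the Generator. Those gradients would push the Generator toward making its images easier to spot as fake.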
GAN Training Dynamics
It is useful to track the Generator and Discriminator losses during training, for example by recording d_loss.item() and g_loss.item() at each step. Unlike most models, a GAN's losses do not simply decrease. They typically oscillate as each network reacts to the other. Ideally, training approaches a balance where the Generator produces realistic images and the Discriminator cannot tell real from fake, outputting about 0.5 for every image.
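That balance point gives concrete numbers to compare against. The sketch below assumes a Discriminator that outputs exactly 0.5 for every image and computes both losses with the same BCE setup as the training loop. Each comes out to ln 2, about 0.693.

```python
import math
import torch
import torch.nn as nn

bce = nn.BCELoss()
n = 16
d_out = torch.full((n, 1), 0.5)  # a Discriminator that is maximally unsure about every image

real_loss = bce(d_out, torch.ones(n, 1))   # -log(0.5)
fake_loss = bce(d_out, torch.zeros(n, 1))  # -log(1 - 0.5)
d_loss = (real_loss + fake_loss) / 2       # same averaging as the training loop
g_loss = bce(d_out, torch.ones(n, 1))      # Generator's "real" target on fake images

print(round(d_loss.item(), 4), round(g_loss.item(), 4), round(math.log(2), 4))  # 0.6931 0.6931 0.6931
```

If d_loss drifts toward 0 and stays there, the Discriminator is winning easily and the Generator receives little useful signal. Losses hovering somewhere around 0.69 are more consistent with a healthy back-and-forth.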
3. Basic Image Generation
Now that we understand VAEs and GANs, let's talk about how they're used for image generation. Both models can create new, realistic-looking images, but they do so in different ways:
- VAEs generate images by sampling a latent vector from the standard normal prior and passing it through the decoder
- GANs generate images by feeding random noise through the Generator
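On the VAE side, generation only needs the decoder: sample latent vectors from the standard normal prior and decode them. In this sketch, the decoder is an untrained stand-in with the same layer sizes as VAE.decode from section 1, an assumption made so the example runs on its own.

```python
import torch
import torch.nn as nn

latent_dim = 20
# Untrained stand-in for a trained VAE decoder (same layer sizes as VAE.decode)
decoder = nn.Sequential(
    nn.Linear(latent_dim, 400), nn.ReLU(),
    nn.Linear(400, 784), nn.Sigmoid(),
)

decoder.eval()
with torch.no_grad():
    z = torch.randn(16, latent_dim)            # 16 samples from the prior N(0, I)
    samples = decoder(z).view(-1, 1, 28, 28)   # 16 images with pixels in [0, 1]
print(samples.shape)  # torch.Size([16, 1, 28, 28])
```

With the trained model from section 1, replace decoder(z) with vae.decode(z). The encoder is not used at generation time.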
Here's a simple example of how you might generate an image using a trained GAN:
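```python
import torch
from torchvision.utils import save_image

# Assuming we have a trained Generator from section 2.
# eval() makes BatchNorm use its running statistics; in training mode
# BatchNorm1d cannot compute statistics for a batch of one sample.
generator.eval()

# Generate a random noise vector
latent_dim = 100
z = torch.randn(1, latent_dim)

# Generate an image without tracking gradients
with torch.no_grad():
    generated_image = generator(z)

# Save the tensor as a PNG; normalize=True shifts pixel values into [0, 1]
# using the tensor's min and max before writing the file
save_image(generated_image, "generated_image.png", normalize=True)
```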
This code generates a single image from a random noise vector using a trained GAN Generator. The image is then saved to a file.
Challenge: Create a Character Generator
Now it's your turn! Try to build a simple character generator using either a VAE or a GAN. Here are some ideas to get you started:
- Use a dataset of simple cartoon characters or emojis for training
- Modify the architecture to work with color images
- Add a condition to the generator (like character type or emotion) to control the output
- Create an interpolation between two characters in the latent space (for VAE)
- Experiment with different loss functions to improve image quality
This challenge will help you apply what you've learned and gain hands-on experience with generative models. Don't worry if your first results don't look perfect - generative models often require a lot of tuning to get right!
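As a starting point for the conditioning idea, one common approach, used by conditional GANs, is to feed the label to the Generator alongside the noise. The sketch below is hypothetical. ConditionalGenerator, the four "emotion" classes, and the 32x32 RGB image size are illustrative assumptions, and it also covers the color-image idea by using three channels.

```python
import torch
import torch.nn as nn

class ConditionalGenerator(nn.Module):
    # Hypothetical sketch: condition on a class label by concatenating
    # a learned label embedding to the noise vector
    def __init__(self, latent_dim, num_classes, img_shape):
        super().__init__()
        self.img_shape = img_shape
        self.label_emb = nn.Embedding(num_classes, num_classes)
        out_dim = int(torch.prod(torch.tensor(img_shape)))
        self.model = nn.Sequential(
            nn.Linear(latent_dim + num_classes, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, out_dim),
            nn.Tanh(),  # pixel values in [-1, 1]
        )

    def forward(self, z, labels):
        x = torch.cat([z, self.label_emb(labels)], dim=1)
        img = self.model(x)
        return img.view(img.size(0), *self.img_shape)

# Example: 4 emotion classes, 32x32 RGB character images (all assumed values)
gen = ConditionalGenerator(latent_dim=100, num_classes=4, img_shape=(3, 32, 32))
z = torch.randn(8, 100)
labels = torch.randint(0, 4, (8,))
imgs = gen(z, labels)
print(imgs.shape)  # torch.Size([8, 3, 32, 32])
```

For this to work, the Discriminator must also receive the label, for example by concatenating the same kind of embedding to the flattened image. Otherwise, it cannot penalize images that don't match the requested condition.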