GAN Tutorial
Arvin Liu @ AISS 2020
DL Review
General Backbone in DL
[Figure: Data (Input) → Model (Function) → Output; the Loss Function compares the Output with the Expected (Answer), and the Optimizer updates the Model.]
- We won't go into the model & optimizer here.
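The backbone above can be sketched in a few lines of plain Python. This is an illustration only: the one-weight model, squared-error loss, toy data, and learning rate are assumptions, not part of the slides.

```python
# Data (Input) and Expected (Answer): samples of y = 3x (toy data, an assumption)
data = [(x, 3.0 * x) for x in [0.5, 1.0, 1.5, 2.0]]

w = 0.0    # Model (Function): y_hat = w * x, with a single parameter w
lr = 0.05  # Optimizer: plain gradient descent with this learning rate

for _ in range(200):
    grad = 0.0
    for x, y in data:
        y_hat = w * x  # Output
        # Loss Function: mean squared error; accumulate its gradient w.r.t. w
        grad += 2 * (y_hat - y) * x / len(data)
    w -= lr * grad     # Optimizer step: move w against the gradient

print(round(w, 3))  # prints 3.0
```

In PyTorch the same roles are played by a `nn.Module`, a loss such as `nn.BCELoss`, and an optimizer such as `torch.optim.Adam`, as in the code cells later on.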
GAN Review
GAN's GOAL
Generate new fake data,
e.g., pictures, music, voice, etc.
Generative Adversarial Network
[Figure: the DL backbone applied to a GAN: the Model has hyper-params and the Expected (Answer) is real data, but the Input and the Loss are still ???. The first ??? is answered with a random vector (z).]
First ???: Model / Generator
[Figure: Skill Point -G-> Character; image source: slides by 李宏毅 (Hung-yi Lee)]
Given skill points (a random vector z), the Generator generates something, e.g., a character.
Generative Adversarial Network
[Figure: the same backbone, now with a random vector as the Input, hyper-params in the Model, and real data as the Expected (Answer).]
Second ???: How to measure?
Generative Adversarial Network
[Figure: another Model, the Discriminator, fills the ??? in the Loss: Data (Input) → Discriminator → Score (Output), i.e., how real the data is.]
Generative Adversarial Network
[Figure: Data (Input) → Model / Discriminator → Score (Output): how real the data is, in [0, 1]. Example scores: 0.01, 0.5, 0.9.]
GAN Framework
[Figure: random vector (z) → Generator (G) → Generated Data; Generated Data and Real Data → Discriminator (D) → Score; the Score drives Loss + Optimizer(G) and Loss + Optimizer(D).]
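For reference, this framework is usually written as a two-player minimax game (the standard formulation from Goodfellow et al., 2014; the slides themselves do not show it). D tries to maximize the value V, and G tries to minimize it:

```latex
\min_G \max_D \; V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
```

Here D(x) is the Discriminator's score in [0, 1] and G(z) is the data generated from the random vector z.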
GAN Intuition (D)
[Figure: the GAN framework diagram again.]
D's target: beat G, i.e., give real data a high score and generated data a low score.
GAN Intuition (G)
[Figure: the GAN framework diagram again.]
G's target: fool D, i.e., make D give generated data a high score.
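In practice, both targets turn into binary cross-entropy losses on D's score. The sketch below is written to match the training cells later in this tutorial, with m the batch size: D uses target 1 for real data and 0 for generated data, and G uses target 1 for its own samples. That G loss is the common "non-saturating" variant rather than directly minimizing log(1 - D(G(z))).

```latex
\mathcal{L}_D = -\frac{1}{m}\sum_{i=1}^{m}\Big[\log D(x_i) + \log\big(1 - D(G(z_i))\big)\Big]

\mathcal{L}_G = -\frac{1}{m}\sum_{i=1}^{m}\log D\big(G(z_i)\big)
```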
GAN vs RL
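RL / GAN in the DL framework
DL framework | RL | GAN
Data (Input) | Environment | Random vector
Model (Function) | Agent | Generator
Output (Output) | Action | Fake Data
Loss Function | Reward | Discriminator's score
Optimizer: updates the Model (the Agent in RL, the Generator in GAN).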
- The discriminator's score in GAN ≈ the environment's reward in RL.
- Unlike the environment in RL, the discriminator in GAN also needs to be updated.
Hands-on GAN
(If the link doesn't open, try opening it in incognito mode.)
Task
Generate MNIST Data
image shape: (28, 28), grayscale
Data Pre-process (cell 1)
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader

# ToTensor scales pixels to [0, 1]; Normalize then maps them to [-1, 1],
# the same range as the Generator's Tanh output.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,))])
mnist = MNIST(root='./data/', train=True, transform=transform, download=True)
# Shuffled mini-batches of 32 (image, label) pairs
data_loader = DataLoader(dataset=mnist, batch_size=32, shuffle=True)
- The torchvision package provides the MNIST dataset.
- Use a DataLoader to get batched data.
- Use the transform to convert PIL images to tensors and normalize them to [-1, 1].
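To see what `Normalize(mean=(0.5,), std=(0.5,))` does to each pixel, here is a plain-Python sketch of the per-pixel arithmetic. The helper names `normalize` and `denormalize` are hypothetical, not torchvision functions. `ToTensor` first scales pixels to [0, 1], and `(x - 0.5) / 0.5` then maps them to [-1, 1]. The inverse is handy when displaying generated images.

```python
def normalize(x, mean=0.5, std=0.5):
    """Per-pixel arithmetic of transforms.Normalize: (x - mean) / std."""
    return (x - mean) / std

def denormalize(y, mean=0.5, std=0.5):
    """Inverse mapping, e.g. to turn a Tanh output back into a [0, 1] pixel."""
    return y * std + mean

# ToTensor output range [0, 1] -> Normalize output range [-1, 1]
print([normalize(p) for p in (0.0, 0.5, 1.0)])  # prints [-1.0, 0.0, 1.0]
print(denormalize(-1.0), denormalize(1.0))       # prints 0.0 1.0
```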
Generator & Discriminator (cell 2-1)
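import torch
import torch.nn as nn

# Discriminator: flattened 28*28 image -> probability that it is real
D = nn.Sequential(
    nn.Linear(28*28, 256),
    nn.LeakyReLU(0.2),
    nn.Linear(256, 256),
    nn.LeakyReLU(0.2),
    nn.Linear(256, 1),
    nn.Sigmoid()).cuda()  # score in (0, 1); .cuda() requires a GPU runtime

# Generator: 64-dim random vector z -> flattened 28*28 image
G = nn.Sequential(
    nn.Linear(64, 256),
    nn.ReLU(),
    nn.Linear(256, 256),
    nn.ReLU(),
    nn.Linear(256, 28*28),
    nn.Tanh()).cuda()  # pixels in (-1, 1), matching the normalized data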
[Figure: activation functions ReLU, Leaky ReLU, tanh (values in (-1, 1)), and sigmoid (values in (0, 1)).]
Note: You can also try a CNN!
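A plain-Python sketch of the four activation functions above. These are scalar versions written for illustration; in the model they come from `torch.nn`. The 0.2 slope matches `nn.LeakyReLU(0.2)` in the Discriminator. Tanh's (-1, 1) range is why the data is normalized to that range. Sigmoid's (0, 1) range is what lets D's output be read as a "how real" score.

```python
import math

def relu(x):
    return max(0.0, x)

def leaky_relu(x, negative_slope=0.2):
    # Keeps a small slope for negative inputs instead of zeroing them
    return x if x >= 0 else negative_slope * x

def tanh(x):
    return math.tanh(x)

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

for x in (-3.0, 0.0, 3.0):
    print(x, relu(x), round(leaky_relu(x), 3), round(tanh(x), 3), round(sigmoid(x), 3))
```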
Optimizer (cell 2-2)
# Binary cross entropy loss and optimizer
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)
- Use BCELoss as the loss function.
- Don't ask me why I use the Adam optimizer.
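With its default mean reduction, `nn.BCELoss` computes -[y*log(p) + (1 - y)*log(1 - p)] averaged over the batch. Here p is D's score and y is the target label. The plain-Python sketch below uses a hypothetical `bce` helper; PyTorch additionally clamps the log terms for numerical safety. It shows why the targets matter: with target 1 the loss is -log(p), and with target 0 it is -log(1 - p).

```python
import math

def bce(scores, targets):
    """Mean binary cross-entropy, the formula nn.BCELoss uses by default."""
    total = 0.0
    for p, y in zip(scores, targets):
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(scores)

# D's step: real images should score near 1, generated ones near 0.
d_loss = bce([0.9], [1]) + bce([0.1], [0])         # correct D -> small loss
d_loss_fooled = bce([0.1], [1]) + bce([0.9], [0])  # fooled D -> large loss

# G's step: G wants D(G(z)) near 1, so its target is 1 (real_labels).
g_loss_good = bce([0.9], [1])  # D is fooled -> small generator loss
g_loss_bad = bce([0.1], [1])   # D spots the fake -> large generator loss

print(round(d_loss, 3), round(d_loss_fooled, 3))  # prints 0.211 4.605
print(round(g_loss_good, 3), round(g_loss_bad, 3))  # prints 0.105 2.303
```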
Start Training! (cell 3-1)
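# In a notebook, run cells 3-2 and 3-3 first so both train_* functions exist.
for epoch in range(100):
    for images, _ in data_loader:  # digit labels are not needed
        batch_size = images.shape[0]  # the last batch may be smaller than 32
        images = images.view(batch_size, 784).cuda()  # flatten 28x28 -> 784
        real_labels = torch.ones(batch_size, 1).cuda()
        fake_labels = torch.zeros(batch_size, 1).cuda()
        train_discriminator()
        train_generator()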
- tensor.view([the shape you wish]) reshapes a tensor; here it flattens each (28, 28) image into a 784-dim vector.
- real_labels: a tensor of ones; fake_labels: a tensor of zeros.
These two are later used as targets for the BCE loss.
- Train your discriminator.
- Train your generator.
- Repeat all of the steps above.
Start Training! (G) (cell 3-2)
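def train_generator():
    z = torch.randn(batch_size, 64).cuda()  # random vectors from N(0, 1)
    fake_image = G(z)
    # G wants D to score its fakes as real, so the targets are real_labels.
    g_loss = criterion(D(fake_image), real_labels)
    g_optimizer.zero_grad()
    g_loss.backward()
    g_optimizer.step()  # only G's parameters are updated here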
[Figure: random vector (z) → Generator (G) → Generated Data → Discriminator (D) → Score → loss → Optimizer(G).]
- torch.randn([shape you want]): generates a random matrix from the normal distribution N(0, 1).
Start Training! (D) (cell 3-3)
def train_discriminator():
    z = torch.randn(batch_size, 64).cuda()
    # detach() keeps this loss from back-propagating into G; only D is trained here.
    d_loss_fake = criterion(D(G(z).detach()), fake_labels)  # fakes -> 0
    d_loss_real = criterion(D(images), real_labels)          # real images -> 1
    d_loss = d_loss_real + d_loss_fake
    d_optimizer.zero_grad()
    d_loss.backward()
    d_optimizer.step()
[Figure: Generated Data and Real Data → Discriminator (D) → Score → loss → Optimizer(D).]
Result
More?
I think it's useless...
Conditional GAN
[Figure: random vector (z) + condition (c) → Generator (G) → Generated Data]
Examples (Condition → Generated Data):
- TTS (Text2Speech): text such as "安安" ("hi") → voice
- Style Transfer
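One common way to feed the condition (c) to the Generator is to concatenate it with the random vector z. For example, c can be a one-hot class label when generating a specific MNIST digit. The plain-Python sketch below works under that assumption; the helper names and the 10-class one-hot encoding are illustrative, not from the slides. With the 64-dim z used in this tutorial, the Generator's first layer would then take 64 + 10 = 74 inputs instead of 64.

```python
import random

NUM_CLASSES = 10  # MNIST digits 0-9 (assumed condition)
Z_DIM = 64        # same size as the random vector z used for G above

def one_hot(label, num_classes=NUM_CLASSES):
    vec = [0.0] * num_classes
    vec[label] = 1.0
    return vec

def generator_input(label, z_dim=Z_DIM):
    """Concatenate a random vector z ~ N(0, 1) with the one-hot condition c."""
    z = [random.gauss(0.0, 1.0) for _ in range(z_dim)]
    return z + one_hot(label)

x = generator_input(7)
print(len(x))     # prints 74
print(x[Z_DIM:])  # the condition part: 1.0 at index 7, 0.0 elsewhere
```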
Q&A?
GAN @ AISS 2020
By Arvin Liu