Generative Adversarial Network with PyTorch
This is a short introduction to using PyTorch. You can find all the code for this post here.
Objectives
- Use DataLoader to load the dataset into batches
- Create the discriminator network
- Create the generator network
- Understand the training loop used to train the GAN model
Now we will build a Generative Adversarial Network (GAN). We will accomplish this by completing each task in the project:
1. Set up the Google Colab runtime
Go to Google Colab (https://colab.research.google.com/) and create a new notebook. To use a GPU, select Runtime → Change runtime type and choose GPU.
Then, import libraries:
import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

torch.manual_seed(42)  # fix the random seed for reproducibility
2. Configurations
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
batch_size = 128   # images per training batch
noise_dim = 64     # size of the generator's input noise vector
lr = 0.0002        # learning rate for both optimisers
beta_1 = 0.5       # Adam momentum terms
beta_2 = 0.99
epochs = 20
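Printing the device is a quick way (my own addition) to confirm that the GPU runtime was actually picked up:
print(device)  # "cuda:0" on a GPU runtime, "cpu" otherwise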
3. Load the MNIST handwritten digits dataset
from torchvision import datasets, transforms as T
train_augs = T.Compose([T.RandomRotation((-20, 20)), T.ToTensor()])
trainset = datasets.MNIST('MNIST/', download=True, train=True, transform=train_augs)
image, label = trainset[5]
plt.imshow(image.squeeze(), cmap='gray')
print("Total images: ", len(trainset))
4. Load the dataset into batches
Import libraries
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
Load the data into the trainloader:
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
print("Total batches: ", len(trainloader))
dataiter = iter(trainloader)
images, _ = next(dataiter)  # dataiter.next() was removed in newer Python/PyTorch versions
print(images.shape)
The following function plots some of the images from a batch:
def show_tensor_images(tensor_img, num_images=16, size=(1, 28, 28)):
    unflat_img = tensor_img.detach().cpu()
    img_grid = make_grid(unflat_img[:num_images], nrow=4)
    # matplotlib expects channels last, so permute (C, H, W) -> (H, W, C)
    plt.imshow(img_grid.permute(1, 2, 0).squeeze())
    plt.show()
show_tensor_images(images, num_images=16)
5. Create discriminator network
If torchsummary is not installed, run:
!pip install torchsummary
Import libraries
from torch import nn
from torchsummary import summary
def get_disc_block(in_channels, out_channels, kernel_size, stride):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size, stride),
        nn.BatchNorm2d(out_channels),
        nn.LeakyReLU(0.2)
    )
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.block_1 = get_disc_block(1, 16, (3, 3), 2)
        self.block_2 = get_disc_block(16, 32, (5, 5), 2)
        self.block_3 = get_disc_block(32, 64, (5, 5), 2)
        self.flatten = nn.Flatten()
        self.linear = nn.Linear(in_features=64, out_features=1)

    def forward(self, images):
        x1 = self.block_1(images)
        x2 = self.block_2(x1)
        x3 = self.block_3(x2)
        x4 = self.flatten(x3)
        x5 = self.linear(x4)
        return x5
D = Discriminator()
D.to(device)
summary(D, input_size=(1, 28, 28))
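If you are wondering why the Linear layer uses in_features=64: with no padding, each Conv2d shrinks the feature map as H_out = floor((H_in - K) / S) + 1. A small helper (hypothetical, added here only for illustration) traces the spatial size through the three blocks:
def conv_out(h, k, s):
    # Conv2d output size with no padding: floor((h - k) / s) + 1
    return (h - k) // s + 1

h = 28
for k, s in [(3, 2), (5, 2), (5, 2)]:  # kernel/stride of block_1..block_3
    h = conv_out(h, k, s)
    print(h)  # prints 13, then 5, then 1
The final feature map is 64 channels of size 1 x 1, so Flatten produces exactly the 64 features the Linear layer expects.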
6. Create generator network
def get_gen_block(in_channels, out_channels, kernel_size, stride, final_block=False):
    if final_block:
        # the last block maps features to a 1-channel image in [-1, 1]
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride),
            nn.Tanh()
        )
    return nn.Sequential(
        nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride),
        nn.BatchNorm2d(out_channels),
        nn.ReLU()
    )
class Generator(nn.Module):
    def __init__(self, noise_dim):
        super(Generator, self).__init__()
        self.noise_dim = noise_dim
        self.block_1 = get_gen_block(noise_dim, 256, (3, 3), 2)
        self.block_2 = get_gen_block(256, 128, (4, 4), 1)
        self.block_3 = get_gen_block(128, 64, (3, 3), 2)
        self.block_4 = get_gen_block(64, 1, (4, 4), 2, final_block=True)

    def forward(self, r_noise_vec):
        # reshape the flat noise vector to (batch, noise_dim, 1, 1)
        x = r_noise_vec.view(-1, self.noise_dim, 1, 1)
        x1 = self.block_1(x)
        x2 = self.block_2(x1)
        x3 = self.block_3(x2)
        x4 = self.block_4(x3)
        return x4
G = Generator(noise_dim)
G.to(device)
summary(G, input_size=(1, noise_dim))
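A quick forward pass (my own sanity check) confirms the shapes: for ConvTranspose2d with no padding the size grows as H_out = (H_in - 1) * S + K, taking the 1 x 1 noise map through 3 → 6 → 13 → 28:
z = torch.randn(8, noise_dim, device=device)  # a batch of 8 random noise vectors
print(G(z).shape)  # torch.Size([8, 1, 28, 28])
One thing to keep in mind: the final Tanh produces pixels in [-1, 1], while ToTensor() gives real images in [0, 1]. Many DCGAN implementations therefore add T.Normalize((0.5,), (0.5,)) to the training transforms so real and generated images share the same range; that would be an optional change to the transforms above.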
Replace the default randomly initialised weights with weights drawn from a normal distribution (mean 0, standard deviation 0.02), as is conventional for DCGANs:
def weights_init(m):
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight, 0.0, 0.02)
        nn.init.constant_(m.bias, 0)
D = D.apply(weights_init)
G = G.apply(weights_init)
7. Create loss functions and optimisers
The discriminator outputs raw logits (there is no final sigmoid in the network), so we use BCEWithLogitsLoss, which applies the sigmoid internally:
def real_loss(disc_pred):
    criterion = nn.BCEWithLogitsLoss()
    ground_truth = torch.ones_like(disc_pred)   # real images are labelled 1
    loss = criterion(disc_pred, ground_truth)
    return loss

def fake_loss(disc_pred):
    criterion = nn.BCEWithLogitsLoss()
    ground_truth = torch.zeros_like(disc_pred)  # fake images are labelled 0
    loss = criterion(disc_pred, ground_truth)
    return loss
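To see these in action (an illustration of my own, with a made-up logit): a confidently "real" prediction yields a real_loss near zero but a large fake_loss:
confident = torch.tensor([10.0])  # large positive logit, i.e. "definitely real"
print(real_loss(confident))  # ~0.00005, almost no penalty against label 1
print(fake_loss(confident))  # ~10.0, heavy penalty against label 0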
D_opt = torch.optim.Adam(D.parameters(), lr=lr, betas=(beta_1, beta_2))
G_opt = torch.optim.Adam(G.parameters(), lr=lr, betas=(beta_1, beta_2))
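As a preview of how these pieces fit together (the full training loop is the subject of the next part; the structure below is my own sketch, not the post's final code):
for epoch in range(epochs):
    for real_img, _ in tqdm(trainloader):
        real_img = real_img.to(device)
        noise = torch.randn(real_img.size(0), noise_dim, device=device)

        # Discriminator step: real images should score 1, fakes 0
        D_opt.zero_grad()
        fake_img = G(noise)
        D_loss = real_loss(D(real_img)) + fake_loss(D(fake_img.detach()))
        D_loss.backward()
        D_opt.step()

        # Generator step: try to make D classify fakes as real
        G_opt.zero_grad()
        G_loss = real_loss(D(G(noise)))
        G_loss.backward()
        G_opt.step()

    show_tensor_images(fake_img)  # visualise progress once per epoch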
To be continued!!!