Generative Adversarial Network with PyTorch

GeoSense ✅
3 min readDec 13, 2021

--

This is a short introduction to building a GAN with PyTorch. You can find all of the code for this post here.

Framework of the GAN

Objectives

  • Use DataLoader to load the dataset in batches
  • Create the Discriminator network
  • Create the Generator network
  • Understand the training loop used to train the GAN model

Now we will build a Generative Adversarial Network (GAN). We will do this by completing each task in the project:

  1. Set up the Google Colab runtime:

Go to Google Colab (https://colab.research.google.com/) to create a new file.

Then, import the libraries:

import torch
torch.manual_seed(42)
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

2. Configurations

# Use the first CUDA GPU when one is present; otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Hyperparameters.
batch_size = 128  # images per mini-batch fed by the DataLoader
noise_dim = 64    # length of each noise vector given to the generator
lr = 2e-4         # Adam learning rate, shared by the D and G optimisers
beta_1 = 0.5      # Adam first-moment coefficient
beta_2 = 0.99     # Adam second-moment coefficient
epochs = 20

3. Load the MNIST handwritten-digit dataset

from torchvision import datasets, transforms as T

# Augment with small random rotations, convert each image to a float tensor in
# [0, 1], then rescale to [-1, 1]. The generator in this file ends in Tanh, so
# its samples lie in [-1, 1]; real images must share that range, otherwise the
# discriminator can tell real from fake simply by the sign of the pixels.
train_augs = T.Compose([
    T.RandomRotation((-20, 20)),
    T.ToTensor(),
    T.Normalize((0.5,), (0.5,)),
])
trainset = datasets.MNIST('MNIST/', download=True, train=True, transform=train_augs)

# Peek at one sample. With cmap='gray' on 2-D data, imshow scales to the data's
# own min/max, so the [-1, 1] range displays correctly.
image, label = trainset[5]
plt.imshow(image.squeeze(), cmap='gray')
print("Total images: ", len(trainset))

4. Load the dataset in batches

Import the libraries:

from torch.utils.data import DataLoader
from torchvision.utils import make_grid

Load the data into trainloader:

# Reshuffle every epoch. len(trainloader) is the number of batches; by default
# the last batch may be smaller than batch_size.
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
print("Total batches: ", len(trainloader))

# Pull one batch to inspect its shape. Use the builtin next(): DataLoader
# iterators in recent PyTorch releases no longer have a `.next()` method.
dataiter = iter(trainloader)
images, _ = next(dataiter)
print(images.shape)

The following function plots some of the images from a batch:

def show_tensor_images(tensor_img, num_images=16, size=(1, 28, 28)):
    """Plot the first ``num_images`` images of a batch as a grid with 4 per row.

    Args:
        tensor_img: batch of images, either (N, C, H, W) or flattened
            (N, C*H*W). It may live on any device and may require grad.
        num_images: how many images from the start of the batch to draw.
        size: per-image (C, H, W) shape, used only to unflatten a 2-D batch.
    """
    unflat_img = tensor_img.detach().cpu()
    # A flattened (N, C*H*W) batch is restored to (N, C, H, W) using `size`.
    if unflat_img.dim() == 2:
        unflat_img = unflat_img.reshape(-1, *size)
    # normalize=True min-max rescales the grid into [0, 1]. Inputs in [-1, 1]
    # (normalised data or Tanh outputs) therefore render without imshow
    # clipping the negative values.
    img_grid = make_grid(unflat_img[:num_images], nrow=4, normalize=True)
    # make_grid returns (C, H, W); imshow expects channels last.
    plt.imshow(img_grid.permute(1, 2, 0).squeeze())
    plt.show()
# Draw a fresh batch and display its first 16 images.
# The builtin next() is used because DataLoader iterators in recent PyTorch
# releases no longer have a `.next()` method.
dataiter = iter(trainloader)
images, _ = next(dataiter)
print(images.shape)
show_tensor_images(images, num_images=16)

5. Create the discriminator network

If torchsummary is not installed, run the following:

# IPython/Colab shell escape (not valid in a plain .py file): installs
# torchsummary, whose `summary` helper is imported below.
!pip install torchsummary

Import the libraries:

from torch import nn
from torchsummary import summary
def get_disc_block(in_channels, out_channels, kernel_size, stride):
    """Build one discriminator stage: Conv2d -> BatchNorm2d -> LeakyReLU(0.2).

    The convolution is unpadded, so each spatial dimension shrinks to
    ``floor((H - kernel_size) / stride) + 1``.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.
        kernel_size: convolution kernel size (int or tuple).
        stride: convolution stride (int or tuple).

    Returns:
        An ``nn.Sequential`` of the three layers.
    """
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size, stride),
        nn.BatchNorm2d(out_channels),
        nn.LeakyReLU(0.2),
    )
class Discriminator(nn.Module):
    """Convolutional critic mapping 1x28x28 images to one raw logit each.

    The three unpadded stride-2 conv stages shrink the spatial size
    28 -> 13 -> 5 -> 1, which leaves exactly 64 features per image for the
    final linear layer. An input whose size does not reduce to 1x1 fails at
    that layer.

    No sigmoid is applied; pair the output with ``nn.BCEWithLogitsLoss``.
    """

    def __init__(self):
        super().__init__()
        self.block_1 = get_disc_block(1, 16, (3, 3), 2)
        self.block_2 = get_disc_block(16, 32, (5, 5), 2)
        self.block_3 = get_disc_block(32, 64, (5, 5), 2)
        self.flatten = nn.Flatten()
        self.linear = nn.Linear(in_features=64, out_features=1)

    def forward(self, images):
        """Return logits of shape (N, 1) for images of shape (N, 1, 28, 28)."""
        x = self.block_1(images)
        x = self.block_2(x)
        x = self.block_3(x)
        x = self.flatten(x)  # (N, 64, 1, 1) -> (N, 64)
        return self.linear(x)
# Build the discriminator directly on the selected device (Module.to returns
# the module itself), then print a per-layer summary for 1x28x28 inputs.
D = Discriminator().to(device)
summary(D, input_size=(1, 28, 28))

6. Create the generator network

def get_gen_block(in_channels, out_channels, kernel_size, stride, final_block=False):
    """Build one generator stage around an unpadded ConvTranspose2d.

    Hidden stages are ConvTranspose2d -> BatchNorm2d -> ReLU. The final stage
    (``final_block=True``) is ConvTranspose2d -> Tanh, which squashes outputs
    into [-1, 1]. Each stage grows every spatial dimension to
    ``(H - 1) * stride + kernel_size``.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the transposed convolution.
        kernel_size: kernel size (int or tuple).
        stride: stride (int or tuple).
        final_block: if truthy, end with Tanh and skip BatchNorm/ReLU.

    Returns:
        An ``nn.Sequential`` of the stage's layers.
    """
    layers = [nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride)]
    if final_block:
        layers.append(nn.Tanh())
    else:
        layers += [nn.BatchNorm2d(out_channels), nn.ReLU()]
    return nn.Sequential(*layers)
class Generator(nn.Module):
    """Transposed-convolution generator mapping noise vectors to 1x28x28 images.

    Each noise vector is treated as a 1x1 feature map with ``noise_dim``
    channels. Four unpadded transposed convolutions upsample it
    1 -> 3 -> 6 -> 13 -> 28. The last stage is built with
    ``final_block=True`` (Tanh), so pixel values lie in [-1, 1].

    Args:
        noise_dim: length of each input noise vector.
    """

    def __init__(self, noise_dim):
        super().__init__()
        self.noise_dim = noise_dim
        self.block_1 = get_gen_block(noise_dim, 256, (3, 3), 2)
        self.block_2 = get_gen_block(256, 128, (4, 4), 1)
        self.block_3 = get_gen_block(128, 64, (3, 3), 2)
        self.block_4 = get_gen_block(64, 1, (4, 4), 2, final_block=True)

    def forward(self, r_noise_vec):
        """Map noise with N * noise_dim elements, typically (N, noise_dim),
        to images of shape (N, 1, 28, 28)."""
        x = r_noise_vec.view(-1, self.noise_dim, 1, 1)
        x = self.block_1(x)
        x = self.block_2(x)
        x = self.block_3(x)
        return self.block_4(x)
# Build the generator on the selected device (Module.to returns the module
# itself) and summarise it for a single noise vector of length noise_dim.
G = Generator(noise_dim).to(device)
summary(G, input_size=(1, noise_dim))

Replace the default random initialisation with weights drawn from a normal distribution:

def weights_init(m):
    """Re-initialise one module DCGAN-style; intended for ``net.apply(weights_init)``.

    - Conv2d / ConvTranspose2d: weights ~ N(0, 0.02). Biases are left as-is.
    - BatchNorm2d (affine): scale gamma ~ N(1, 0.02), shift beta = 0.
    All other module types are left untouched.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d) and m.weight is not None:
        # gamma must be centred on 1, not 0. With gamma ~ N(0, 0.02), every
        # BatchNorm layer would multiply its normalised output by ~0.02 and
        # nearly erase the signal at initialisation.
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.constant_(m.bias, 0)
# Run weights_init recursively over every submodule of D, then G.
# Module.apply returns the same module, so the names keep their objects.
D, G = D.apply(weights_init), G.apply(weights_init)

7. Create the loss functions and load the optimisers

def real_loss(disc_pred):
    """Binary cross-entropy (on logits) of ``disc_pred`` against "real" labels.

    Args:
        disc_pred: raw discriminator logits of any shape.

    Returns:
        Scalar tensor: the mean loss against an all-ones target of the same shape.
    """
    targets = torch.ones_like(disc_pred)
    return nn.BCEWithLogitsLoss()(disc_pred, targets)
def fake_loss(disc_pred):
    """Binary cross-entropy (on logits) of ``disc_pred`` against "fake" labels.

    Args:
        disc_pred: raw discriminator logits of any shape.

    Returns:
        Scalar tensor: the mean loss against an all-zeros target of the same shape.
    """
    targets = torch.zeros_like(disc_pred)
    return nn.BCEWithLogitsLoss()(disc_pred, targets)
# One Adam optimiser per network; both use the same learning rate and betas.
D_opt = torch.optim.Adam(
    params=D.parameters(),
    lr=lr,
    betas=(beta_1, beta_2),
)
G_opt = torch.optim.Adam(
    params=G.parameters(),
    lr=lr,
    betas=(beta_1, beta_2),
)

To be continued!!!

BECOME a WRITER at MLearning.ai // text-to-video // Detect AI img

--

--

GeoSense ✅
GeoSense ✅

Written by GeoSense ✅

🌏 Remote sensing | 🛰️ Geographic Information Systems (GIS) | ℹ️ https://www.tnmthai.com/medium

No responses yet