In the previous article, we introduced the two main components in the architecture of a GAN: the generator and the discriminator. To understand them better, in this article we will look at each of them in more detail.
I. Generator
A Generative Model is a type of model in Machine Learning and Deep Learning that is used to generate new data with properties similar to the original training data. The goal of a generative model is to learn the structure and characteristics of the training data and then generate new data samples based on that learned knowledge.
This kind of model not only has applications in creating new images, sounds, and texts, but can also be applied in many other fields such as creative arts, design, and research. Advances in generative modeling continue to open up opportunities and challenges for AI and Deep Learning. Below is a simple PyTorch Generator that maps a 100-dimensional noise vector to a flattened 28×28 image:
```python
import torch.nn as nn

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        # Map a 100-dimensional noise vector to a 784-dimensional (28x28) image
        self.model = nn.Sequential(
            nn.Linear(100, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
            nn.Tanh()  # output pixels in [-1, 1], matching normalized images
        )

    def forward(self, x):
        x = self.model(x)
        return x
```
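As a quick sanity check (a usage sketch, not part of training), we can feed the Generator a batch of random noise vectors; the shapes follow the layer sizes defined above:

```python
import torch

z = torch.randn(16, 100)  # batch of 16 random noise vectors
fake = Generator()(z)     # shape (16, 784): flattened 28x28 images in [-1, 1]
```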
II. Discriminator
The Discriminator is an important component of the GAN (Generative Adversarial Network) model. Its main task is to distinguish between real data and fake data generated by the Generator network.
When training the GAN model, the Discriminator is given both real data and fake data as input. Its task is to output the probability that a given data sample is real. The Discriminator tries to maximize the accuracy of this distinction, while the Generator tries to generate fake data that the Discriminator cannot tell apart from the real data.
GAN training proceeds by iteratively updating the Discriminator and the Generator. The Discriminator is updated to improve its ability to distinguish real data from fake data, while the Generator is updated to produce better fake data, thereby fooling the Discriminator.
Because these two components work together and compete with each other, the GAN model is able to generate high-quality fake data that is close to the real data. This makes it increasingly difficult for the Discriminator to tell real from fake, and forces the Generator to capture the characteristics and distribution of the original data. A simple Discriminator that mirrors the Generator above looks like this:
```python
import torch.nn as nn

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        # Map a flattened 28x28 image to a single "real vs. fake" probability
        self.model = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Sigmoid()  # probability that the input image is real
        )

    def forward(self, x):
        x = self.model(x)
        return x
```
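Chaining the two (still untrained) networks shows how they will interact during training: the Discriminator assigns a probability to the Generator's output.

```python
import torch

z = torch.randn(1, 100)
p = Discriminator()(Generator()(z))  # probability the generated image is judged "real"
```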
III. Loss Function
The loss function of the GAN model is a single function that combines the Discriminator's objective and the Generator's objective:
$\underset{G}{\min}\,\underset{D}{\max}\,V(D,G) = E_{\textbf{x}\sim p_{data}(\textbf{x})}[\log D(\textbf{x})] + E_{\textbf{z}\sim p_z(\textbf{z})}[\log(1 - D(G(\textbf{z})))]$
Let's analyze this complex loss function together:
- The Generator network is denoted $G$ and the Discriminator network is denoted $D$.
- $G(\textbf{z})$ is the image generated by the Generator from a noise vector $\textbf{z}$.
- $D(\textbf{x})$ is the Discriminator's predicted probability that the image $\textbf{x}$ is real.
- $D(G(\textbf{z}))$ is the Discriminator's predicted probability that an image generated by the Generator is real.
- $E$ is the expectation, simply understood as the average over all data, with $\textbf{x} \sim p_{data}(\textbf{x})$ being the data in the training set and $\textbf{z} \sim p_z(\textbf{z})$ being the noise distribution.
From the loss function above, it can be seen that the Generator and the Discriminator are trained in opposition: while $D$ tries to maximize the loss, $G$ tries to minimize it. The GAN training process ends when the model reaches an equilibrium between the two networks, called a Nash equilibrium.
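In practice, this minimax game is implemented as two separate losses, one per network. The Discriminator minimizes the negated value function, while the Generator is usually trained with the non-saturating variant, which gives stronger gradients early in training:

$L_D = -\left(E_{\textbf{x}\sim p_{data}(\textbf{x})}[\log D(\textbf{x})] + E_{\textbf{z}\sim p_z(\textbf{z})}[\log(1 - D(G(\textbf{z})))]\right)$

$L_G = -E_{\textbf{z}\sim p_z(\textbf{z})}[\log D(G(\textbf{z}))]$

This is exactly what the binary cross-entropy loss computes in the training loop below: the Discriminator is trained with real labels on real images and fake labels on generated ones, while the Generator is trained with real labels on its own outputs.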
```python
import torch
import torch.nn as nn
import torch.optim as optim

# Initialize the Generator network and the Discriminator network
generator = Generator()
discriminator = Discriminator()

# Define loss function and optimizers
criterion = nn.BCELoss()
optimizer_generator = optim.Adam(generator.parameters(), lr=0.0002)
optimizer_discriminator = optim.Adam(discriminator.parameters(), lr=0.0002)

# Train the GAN model (train_loader is assumed to yield batches of real images)
num_epochs = 100
for epoch in range(num_epochs):
    for batch_idx, (real_images, _) in enumerate(train_loader):
        # Determine batch size and prepare training data
        batch_size = real_images.size(0)
        real_images = real_images.view(-1, 784)
        real_labels = torch.ones(batch_size, 1)
        fake_labels = torch.zeros(batch_size, 1)

        # Discriminator training: real images should score 1, fake images 0
        discriminator.zero_grad()
        outputs_real = discriminator(real_images)
        loss_real = criterion(outputs_real, real_labels)
        z = torch.randn(batch_size, 100)
        fake_images = generator(z)
        # detach() so this backward pass does not touch the Generator's weights
        outputs_fake = discriminator(fake_images.detach())
        loss_fake = criterion(outputs_fake, fake_labels)
        loss_discriminator = loss_real + loss_fake
        loss_discriminator.backward()
        optimizer_discriminator.step()

        # Generator training: try to make the Discriminator output 1 on fakes
        generator.zero_grad()
        z = torch.randn(batch_size, 100)
        fake_images = generator(z)
        outputs = discriminator(fake_images)
        loss_generator = criterion(outputs, real_labels)
        loss_generator.backward()
        optimizer_generator.step()
```
IV. Evaluation and Benchmarking
Evaluating a Generative Adversarial Network (GAN) is an important step in assessing the performance and quality of a trained model. Here are some common evaluation methods for GAN models:
- Evaluating image quality: For imaging applications, a common approach is to use image quality metrics such as SSIM (Structural Similarity Index), PSNR (Peak Signal-to-Noise Ratio), or FID (Fréchet Inception Distance). These metrics measure the similarity between generated images and real images; for SSIM and PSNR, higher values indicate better quality, while for FID, lower values are better.
- Assessing new data generation: A good GAN model should be able to generate new and diverse data. To assess this, one can count the number of distinct samples generated, or measure sample diversity through metrics such as the Inception Score (see the sketch after this list).
- Assessing learning ability and stability: A good GAN model should learn quickly and stably. This can be evaluated by monitoring the evolution of the Generator's and Discriminator's loss functions during training, ensuring that they converge to stable values and reach a good balance between the two components.
- Evaluating the interaction between Generator and Discriminator: A good GAN model should have an effective interaction between the two networks. This can be assessed by examining the Discriminator's ability to distinguish the fake samples produced by the Generator, and by checking that the Generator is capable of fooling the Discriminator.
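For illustration, here is a minimal sketch of computing the Inception Score mentioned above. It assumes you have already obtained class logits for the generated samples from a pretrained classifier such as Inception V3; the `logits` argument stands in for that assumed input:

```python
import torch
import torch.nn.functional as F

def inception_score(logits):
    # logits: tensor of shape (N, num_classes) from a pretrained classifier
    probs = F.softmax(logits, dim=1)            # p(y|x) for each generated sample
    marginal = probs.mean(dim=0, keepdim=True)  # p(y), averaged over all samples
    # IS = exp(E_x[KL(p(y|x) || p(y))]): high when samples are both confidently
    # classified (sharp p(y|x)) and diverse (flat marginal p(y))
    kl = (probs * (probs.log() - marginal.log())).sum(dim=1)
    return kl.mean().exp().item()
```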
Here is an example of evaluating a GAN model by calculating the FID (Fréchet Inception Distance), using the Inception V3 model to compute feature vectors from real and fake images:
```python
import torch
import torch.nn.functional as F
import torchvision
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from scipy.linalg import sqrtm

def torch_cov(m, rowvar=False):
    # Covariance matrix of feature vectors (rows are samples when rowvar=False)
    if not rowvar:
        m = m.t()
    m = m - m.mean(dim=1, keepdim=True)
    return m @ m.t() / (m.size(1) - 1)

def calculate_fid(real_images, fake_images, batch_size, device):
    # Extract feature vectors from real and fake images with Inception V3
    # (for brevity, the classifier logits are used as features here)
    inception_model = torchvision.models.inception_v3(pretrained=True, transform_input=False).to(device)
    inception_model.eval()
    with torch.no_grad():
        real_features = []
        fake_features = []
        for i in range(0, len(real_images), batch_size):
            real_batch = real_images[i:i+batch_size].to(device)
            fake_batch = fake_images[i:i+batch_size].to(device)
            # Inception V3 expects 3-channel 299x299 inputs
            real_batch = F.interpolate(real_batch.repeat(1, 3, 1, 1), size=(299, 299))
            fake_batch = F.interpolate(fake_batch.repeat(1, 3, 1, 1), size=(299, 299))
            real_features.append(inception_model(real_batch))
            fake_features.append(inception_model(fake_batch))
        real_features = torch.cat(real_features, dim=0)
        fake_features = torch.cat(fake_features, dim=0)
    # Calculate mean and covariance matrix of the feature vectors
    real_mu, real_sigma = torch.mean(real_features, dim=0), torch_cov(real_features, rowvar=False)
    fake_mu, fake_sigma = torch.mean(fake_features, dim=0), torch_cov(fake_features, rowvar=False)
    # FID = ||mu_r - mu_f||^2 + Tr(S_r + S_f - 2*sqrt(S_r S_f))
    covmean = sqrtm((real_sigma @ fake_sigma).cpu().numpy())
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # discard tiny imaginary parts from numerical error
    fid = (torch.norm(real_mu - fake_mu)**2
           + torch.trace(real_sigma + fake_sigma)
           - 2 * float(np.trace(covmean)))
    return fid.item()

# Load MNIST data
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,))
])
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)
test_loader = DataLoader(test_dataset, batch_size=100, shuffle=True)

# Initialize the (trained) Generator and Discriminator
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Model evaluation
num_samples = 1000
noise = torch.randn(num_samples, 100).to(device)
fake_images = generator(noise).view(-1, 1, 28, 28).detach().cpu()  # reshape flat outputs to image form
real_images = test_dataset.data[:num_samples].float().unsqueeze(1) / 127.5 - 1  # scale to [-1, 1]
fid = calculate_fid(real_images, fake_images, 100, device)
print(f"FID: {fid:.4f}")
```
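Note that this is a simplified sketch: standard FID implementations use the 2048-dimensional pool3 activations of Inception V3 rather than the classifier logits, and apply Inception's own input preprocessing, so the absolute values from this version will not match published scores.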
We compute the mean and covariance matrix of the feature vectors to obtain the FID measure. The smaller the FID, the closer the fake images are to the real ones in quality.