Shine's dev log
[Pytorch] GAN 구현 및 학습 본문
1. 개요
https://github.com/godeastone/GAN-torch
Pytorch 로 구현한 GAN 전체 코드는 위 git repository에서 확인할 수 있다.
2. GAN
GAN은 2014년 Ian Goodfellow 님에 의해 개발되었다.
GAN 논문에 대한 자세한 정보는 아래 글을 참고하도록 하고, 이번에는 GAN의 전체적인 구조를 대략 살펴보고 코드로 구현하는데 집중을 해도록 하자.
https://ddongwon.tistory.com/117
우선 GAN은 한마디로 Generator와 Discriminator를 adversarial 하게 잘 학습시켜 기존의 데이터셋과 유사한 데이터를 생성하도록 하는 생성모델이다.
Generator는 noise vector 'z' 을 input으로 받아서, 기존의 데이터셋과 유사한 확률 분포를 가지는 데이터 샘플을 output으로 뽑아내는 녀석이다.
Discriminator 는 실제 데이터셋에서 나온 데이터(real) 와, Generator에서 나온 데이터(fake)를 input으로 받아 이 둘을 최대한 real 과 fake로 classification 하는 녀석이다.
즉, Generator는 최대한 그럴듯한 데이터를 만들어내는게 목표고, Discriminator는 최대한 real 과 fake 데이터를 구분하려하는게 목표이다.
Discriminator는 D(x) -> 1(real), D(G(z)) -> 0(fake)로 학습시키고, Generator는 D(G(z)) 의 결과가 최대한 1(real)로 착각하도록 학습시킨다. 흔히 볼 수 있는 [그림 1]의 GAN의 loss 함수는 위 개념을 바탕으로 나온 것이다.
3. 구현
그럼 이제 GAN을 어떻게 구현할 수 있을지 코드를 보며 이해해보자.
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.linear1 = nn.Linear(img_size, hidden_size2)
self.linear2 = nn.Linear(hidden_size2, hidden_size1)
self.linear3 = nn.Linear(hidden_size1, 1)
self.leaky_relu = nn.LeakyReLU(0.2)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
x = self.leaky_relu(self.linear1(x))
x = self.leaky_relu(self.linear2(x))
x = self.linear3(x)
x = self.sigmoid(x)
return x
우선 Discriminator는 위와 같이 정의할 수 있다. 일반적인 multi layer neural network로 구성되어 있다.
총 3개의 Linear layer로 구성되어 있는데, 첫번째 layer에서는 MNIST 이미지 사이즈 (1 x 28 x 28 = 784)을 입력받고, 마지막 레이어에서는 classification을 위해 1개의 노드로 정리된다.
각 레이어 사이에는 activation function으로 leaky ReLU 함수가 사용되었으며, 마지막에는 확류로 표현하기 위해 sigmoid 함수가 사용되었다.
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.linear1 = nn.Linear(noise_size, hidden_size1)
self.linear2 = nn.Linear(hidden_size1, hidden_size2)
self.linear3 = nn.Linear(hidden_size2, img_size)
self.relu = nn.ReLU()
self.tanh = nn.Tanh()
def forward(self, x):
x = self.relu(self.linear1(x))
x = self.relu(self.linear2(x))
x = self.linear3(x)
x = self.tanh(x)
return x
Generator는 Discriminator와 반대로 구성되어 있다.
역시 총 3개의 Linear layer로 구성되어 있으며, 입력값으로 noise vector 'z'의 크기만큼의 노드가 사용되고, 마지막 layer에서는 실제 MNIST 데이터의 크기 (1 x 28 x 28 = 784) 개의 노드로 정리된다.
각 layer 사이에는 activation function으로 ReLU 함수가 사용되었으며, 마지막 layer 에는 tanh 함수가 사용되었다.
이제 Generator와 Discriminator를 정의했으니, 본격적으로 학습을 시작해보자.
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=learning_rate)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=learning_rate)
우선 학습에 사용될 Loss function으로는 BCELoss를 사용하였으며, 국민 optimizor인 Adam을 사용하였다.
이때, genrator와 discriminator는 서로 따로따로 학습되므로 각각 optimizer를 구분지어 정의해주어야 한다.
for epoch in range(num_epoch):
for i, (images, label) in enumerate(data_loader):
# 라벨을 만들어 줍니다. 1 for real, 0 for fake
real_label =
torch.full((batch_size, 1), 1, dtype=torch.float32).to(device)
fake_label =
torch.full((batch_size, 1), 0, dtype=torch.float32).to(device)
# MNIST dataset의 데이터를 flatten 하게 reshape 해줍니다.
real_images = images.reshape(batch_size, -1).to(device)
이제 for문을 통해 각 epoch 마다 학습을 시켜주게 된다.
학습을 위해 [batch size, 1] 크기의 모두 1로 구성된 real label 의 tensor와 모두 0으로 구성된 fake label의 tensor를 만들어 주었다.
또한 [1 x 28 x 28] 크기의 MNIST 데이터를 [batch size, 784] 의 크기로 flatten 해주는 과정을 거쳤다.
# +---------------------+
# | train Generator |
# +---------------------+
# Initialize grad
g_optimizer.zero_grad()
d_optimizer.zero_grad()
# fake image를 generator와 noize vector 'z' 를 통해 만들어주기
z = torch.randn(batch_size, noise_size).to(device)
fake_images = generator(z)
# loss function에 fake image와 real label을 넘겨주기
# 만약 generator가 discriminator를 속이면, g_loss가 줄어든다.
g_loss = criterion(discriminator(fake_images), real_label)
# backpropagation를 통해 generator 학습
g_loss.backward()
g_optimizer.step()
우선 Generator를 학습시켜주자. Discriminator 를 먼저 학습시키든 Generator를 먼저 학습시키든 상관없지만, 중요한 것은 각자 따로따로 학습시켜줘야 한다는 점이다.
우선 noise vector 'z' 를 torch.randn 함수를 통해 랜덤한 값으로 채워준다.
이후 앞서 선언한 generator에 z를 넣어줌으로써 [28 x 28 = 784] 크기의 이미지 데이터를 생성하게 된다. 즉 G(z)는 Generator가 생성한 batch size 개수만큼의 이미지가 된다.
앞서 2장에서 설명했듯이 Generator는 D(G(z))의 성능을 낮추는 방향으로 학습된다.
따라서 loss 함수에 D(G(z))와 real label을 함께 넣어준다.
이렇게 할 경우, Discriminator가 제대로 판단을 할 경우(fake라 판단) Generator는 올바른 방향으로 데이터를 생성하지 못했다고 생각하게 되고, Discriminator 가 제대로 판단하지 못할 경우(real로 판단) Generator는 올바른 방향으로 데이터를 생성했다고 생각하게 된다.
이런 과정을 통해 Generator의 성능이 높아지는 방향으로 학습이 진행되게 된다.
# +---------------------+
# | train Discriminator |
# +---------------------+
# Initialize grad
d_optimizer.zero_grad()
g_optimizer.zero_grad()
# generator와 noise vector 'z'로 fake image 생성
z = torch.randn(batch_size, noise_size).to(device)
fake_images = generator(z)
# fake image와 fake label, real image와 real label을 넘겨 loss 계산
fake_loss = criterion(discriminator(fake_images), fake_label)
real_loss = criterion(discriminator(real_images), real_label)
d_loss = (fake_loss + real_loss) / 2
# backpropagation을 통해 discriminator 학습
# 이 부분에서는 generator는 학습시키지 않음
d_loss.backward()
d_optimizer.step()
다음으로 Discriminator를 학습시켜주자.
우선 앞서 했던거 같이 z를 generator에 통과시켜 fake image를 만들어준다.
D(G(z)) 값을 loss function에 fake label과 함께 넣어 fake loss를 구해주고, D(x) 값을 loss function에 real label과 함게 넣어 real loss를 구해준다.
이렇게 구한 두 fake / real loss를 평균내서 전체 discriminator 의 loss값을 구해준다.
이렇게 하면 Discriminator가 제대로 fake 와 real 이미지를 판단할 수 있는 방향으로 학습이 진행되게 된다.
이렇게 되면 한 epoch 의 학습이 끝나게 된다.
처음에는 generator가 터무니 없는 데이터를 생성하기 때문에 discriminator가 어렵지 않게 이를 분류할 수 있다. 그래서 d_loss 값은 작게 나오고, g_loss값은 크게 나온다.
하지만 학습이 진행될수록 d_loss값은 커지고 g_loss값은 점점 작아지는 것을 확인할 수 있다. 이는 Discriminator가 점점 진짜와 가짜 이미지를 판단하기 어려워지고, Generator가 점점 진짜같은 가짜 이미지를 생성해낸다는 뜻이다.
즉, GAN의 기존 의도와 딱 맞게 학습이 되어간다는 뜻이다.
실제 코드를 돌려보면 알 수 있겠지만, 대략 200 epoch 정도만 학습해줘도 [그림 2]와 같이 꽤나 그럴듯한 데이터가 생성되는 것을 확인할 수 있다.
만약 Discriminator와 Generator를 조금 더 정교하게 설계하거나 batch normalization 등의 다양한 테크닉을 적용한다면 보다 정교하고 그럴듯한 데이터를 생성해낼 수 있을 것이다.
위 내용은 공부하며 정리한 것으로, 오류가 있을 수 있습니다.
'머신러닝' 카테고리의 다른 글
[Pytorch] Conditional GAN 구현 및 학습 (CGAN) (4) | 2022.03.19 |
---|---|
[ML] Deepfake Detection 성능 분석 (8) | 2022.02.26 |
KL divergence와 JSD의 개념 (feat. cross entropy) (1) | 2022.02.19 |
[Recommender system] 영화 추천 시스템 (0) | 2022.01.23 |
PCA (Principal Component Analysis) : 주성분 분석 이란? (31) | 2021.12.21 |