Conditional Generative Adversarial Nets
Introduction
기존의 Generative Adversarial Nets
생성 모델을 학습하기 위해서, 데이터의 확률적 계산의 어려움을 대체하는 GAN. 이 모델은 Markov chain이나 별다른 추측 필요없이 오직 Back-Propagation으로만 학습이 가능하다. (더 자세한 내용은 이전 글 참조)
생성되는 데이터를 조절할 수 있을까?
하지만 기존의 모델은 Unconditional 생성 모델로, 데이터가 생성되는 종류를 제어할 방법이 없다. 하지만 이것이 가능하다면 성능이 향상되지 않을까? 이를 조건 설정을 통해 데이터 생성 과정을 제어하고자 하는 것이 바로 이 논문에서 다룰 Conditional Generative Adversarial Nets이다.
Conditional Generative Adversarial Nets
GAN vs CGAN
두 모델의 차이점은 바로 레이블이라는 추가적인 정보가 추가된 점이다. 실제로 들어가는 형태는 합쳐진 형태로 들어가서 입력이 단지 조건부로 바뀐 점을 제외하고는 기존 GAN과 구조나 학습방법은 크게 다를것이 없다. 이때 추가적인 정보로 간주되는 어떠한 종류의 정보라도 사용될 수 있다.
CGAN의 Objective Function
마찬가지로 기존의 GAN과 크게 다른점이 없다. 들어가는 입력만 조건부로 바뀌었다.
Implementation
MNIST 데이터 셋을 예시로 들자. 만약 숫자 7을 생성하고 싶다면, Class의 Label을 활용하여 [0000000100]를 추가하여 사용한다. 이를 다음과 같이 구현해 볼 수 있다.
class Generator(nn.Module):
def __init__(self, params):
super().__init__()
self.num_classes = params['num_classes']
self.latent_dim = params['latent_space']
self.input_size = params['input_size']
# Label embedding matrix
self.label_emb = nn.Embedding(self.num_classes, self.num_classes)
self.model = nn.Sequential(
*self.block(self.latent_dim + self.num_classes, 128, normalize=False),
*self.block(128,256),
*self.block(256,512),
*self.block(512,1024),
nn.Linear(1024, int(np.prod(self.input_size))),
nn.Tanh()
)
def block(self, in_channels, out_channels, normalize=True):
layers = []
layers.append(nn.Linear(in_channels, out_channels)) # fc layer
if normalize:
layers.append(nn.BatchNorm1d(out_channels, 0.8)) # Batch Normalization
layers.append(nn.LeakyReLU(0.2)) # LeakyReLU
return layers
def forward(self, noise, labels):
# Concatenate label embedding and image to produce input
gen_input = torch.cat((self.label_emb(labels), noise), -1)
img = self.model(gen_input)
img = img.view(img.size(0), *self.input_size)
return img
먼저 클래스의 라벨을 활용해서 Label Embedding Matrix를 생성하고, 이를 입력시키기 전에 기존의 노이즈와 합쳐서 사용하는 것으로 간단하게 구현가능하다. 기존의 GAN의 경우 생성되는 데이터를 조절할 수 없어서, 랜덤하게 나오는 반면, CGAN의 경우 단지 입력에 조건을 붙이는 방법을 통해서 결과물을 조절이 가능해지는 것을 확인해 볼 수 있다.
전체적인 코드는 아래의 Github에서 확인하자.
Conclusion
결과적으로 CGAN은 단지 기존의 GAN의 입력값에 Condition을 추가하는 것으로 결과물을 조절하는 효과를 얻을 수 있다. 예시로 사용한 MNIST의 경우 클래스의 라벨을 사용해서 조건을 추가했지만, 이러한 것이 아니라 이미지나 다른 무엇이든 분별이 될 수 있는 어떠한 형태라면 사용이 가능하다는 점에서 활용도가 크다고 생각한다.
기존의 GAN과 모델 구조나 학습 과정은 거의 다르지 않으며, Generator는 Condition을 활용해 무엇을 출력할 것인지에 대한 지표로 삼는다는 점과 Discriminator는 Condition을 활용해 무엇이 올 것인가 기대하며 예측을 진행한다는 점의 차이일 뿐이다.
의견과 질문은 언제나 감사합니다.
'Paper > Generative Model' 카테고리의 다른 글
Generative Adversarial Networks (0) | 2022.03.07 |
---|