Unrolled GAN이 풀고자 하는 문제
오늘 소개할 논문에서 풀고자 하는 문제는 GAN의 고질적 문제 중 하나인 불안정한 학습의 안정화(stabilize)입니다. GAN의 불안정성은 크게 세 가지 문제로 정리되는데요:
GAN의 불안정성
- Mode collapsing or dropping (generator)
- Generator and discriminator oscillating during training
- No learning when the power between generator and discriminator is unbalanced.
세 가지 문제가 사실 다 연결되어 있습니다. 게다가 Kullback Leibler (KL)와 같이 mode-covering divergence를 사용하는 경우에도 생성 모델이 수렴하고 나서도 전체 분포를 포괄(cover)하지 못하는 경우가 있습니다.
말로만 불안정하다라고 얘기하는 것보다 예시를 보는게 감이 팍 오지요. 아래 그림은 그 중 mode collapsing이라는 문제를 보여줍니다.
Unrolled GAN(위) vs. standard GAN(아래)
가장 오른쪽 열에 target이 우리가 찾고자 하는 data distribution이라고 생각하시면 됩니다. 2D Gaussian 분포들을 여러 개 섞은 형태의 dataset이네요. 왼쪽에서 오른쪽으로 갈 수록 학습이 진행되는 것이고(step 0 ~step 25k) 각각의 step에서 보여주는 그림은 해당 step에서 학습한 generator distribution이라고 생각하시면 됩니다.
위 줄의 unrolled GAN은 점차 여러 mode를 가진 분포를 찾아가는 반면 아래 줄의 standard GAN은 각각의 mode들을 번갈아 돌아가면서 전체적인 데이터 분포를 찾지 못하고 한 번에 하나의 mode에만 weight를 주는 것을 보실 수 있습니다. 즉, mode collapsing이 일어나는 것이죠.
좀 더 구체적인 예시로 이해를 해보겠습니다.
MNIST 데이터로 학습한 결과 비교
Mode collapse 문제가 생기는 이유
지난 글에서 DCGAN에 대해 소개하면서도 같은 내용(stability)에 대해 말씀을 드린 적이 있지요. GAN이 불안정한 이유는 태생적으로 minimax problem을 풀어야하기 때문이기도 하지만, 실제 학습을 할 때는 함수 공간에서 학습을 하는 것이 아니라 Neural Network를 사용하면서 생기는 이론적 가정과의 괴리 때문에도 수렴한다는 보장을 할 수 없게 됩니다.
* 이에 대한 가벼운 소개로는 DCGAN 글을 보시고 좀 더 자세한 이론적 설명과 증명은 GAN 글(1), (2)를 보시면 되겠습니다.
물론 실제 상황에서는(empirically) Neural Network를 사용하는 방식이 꽤나 well-behaved하고 있고, 특히 DCGAN과 같이 네트워크의 구조를 조심스럽게 설계하면 이 문제를 상당 부분 해결할 수 있다는 것이 확인되었습니다. 그러나 여전히 실제에서 사용할 때 hyperparameters를 조정할 수 있는 범위가 좁은 편이고 학습에 실패하는 경우가 종종 있었죠.
사실 우리가 매 step마다 optimal discriminator $D^*$를 계산할 수 있다면야 이런 문제가 생기지 않습니다.....만 Neural Network로 수렴이 될 때까지 계산을 매우 여러 번 오래 해야하기 때문에 현실적으로 불가능하지요(infeasible computational cost).
$$D^*(x) = \frac{p_{data}(x)}{p_{data}(x) + p_G(x)}.$$
* 이 문제는 상당히 잘 알려져 있고 풀기가 쉽지 않기 때문에 NIPS 2016 Tutorial에서도 짧지만 한 단락을 할애하여 mode collapse 문제가 일어나는 이유에 대해 언급을 하고 있습니다. 상당히 직관적으로 이유를 잘 설명해주고 있기에 여기에 같이 소개해보도록 하겠습니다.
원래 우리가 풀고자하는 GAN 문제는 다음과 같은 minimax problem입니다:
$$G^* = \min_G \max_D V(G,D).$$
그런데 우리가 실제 학습을 할 때는 $G$와 $D$에 대한 update를 번갈아가며 해주기 때문에 Neural network의 입장에서는 이러한 minimax problem과 아래와 같은 maximin problem이 구별이 되지 않습니다:
$$G^* = \max_D \min_G V(G,D).$$
문제는 위와 같은 maximin problem의 경우에서 생깁니다. 수식의 안 쪽부터 살펴보겠습니다. $G$에 대한 minimization 문제가 먼저 있기 때문에 generator의 입장에서는 현재 고정되어있는 discriminator (non-optimal)가 가장 헷갈려 할 수 있는 sample 하나만 즉, value $V$를 가장 최소화할 수 있는 mode 하나만을 내보내면 그만입니다.
또 다른 성격의 문제로 GAN으로 만든 sample image의 질을 정량적으로 측정할 수 있는 방법이 마땅치 않다는 것이 있습니다. 좀 더 논문의 글을 인용하며 말하자면 GAN의 training loss를 계산하는 것이 intractable하기 때문입니다.
* 이 부분은 다음에 리뷰할 Wasserstein GAN(WGAN)에서 주요 포인트 중 하나가 되지요.
따라서 사람의 주관적인 판단을 사용하여 측정하는 방법을 사용하는데 이런 방식은 mode collapse 문제에 그리 도움을 주지 못합니다. 사람이 볼 때는 하나의 mode로부터 뽑은 sample이 그럴듯하면 이미지의 질이 좋다고 생각하기 쉽고, generator가 모든 mode들을 cover하지 못하더라도 몇 가지 한정된 mode들을 찾아서 이로부터 적당히 이미지 종류가 섞여 생성되는 경우, 딱히 이 sample들이 전체가 아닌 한정된 mode에서 나오는 것인지 알기가 쉽지 않기 때문입니다.
Unrolled GAN이란?
그럼 이 문제를 unrolled GAN에서는 어떻게 풀고자 하는 것일까요? 힌트는 이름에 있습니다. 생각보다는 단순한 아이디어인데요 수식으로 설명이 들어가기 전에 큰 줄기만 말씀드리자면 generator를 update할 때 이 update에 대해 discriminator가 어떻게 반응하는지에 대한 정보를 추가로 더 주겠다는 얘기입니다.
어떻게 그런 정보를 주는가?도 매우 쉽습니다. 논문에서 나오는 아래 도식을 보시면 설명이 되어있는데요
Unrolled GAN 도식 (unrolling three steps 예시)
그림이 직관적이지 않아 마음에 들지는 않습니다만 제가 다 읽고 설명하려고 해도 이 그림 이상으로 더 낫게 그리기는 어렵다는 생각이 드네요. 세 번 unrolling한 GAN에 대해 설명을 하고 있고 왼쪽으로 돌아오는 화살표들이 모두 backpropagation을 의미합니다.
요약해보자면 generator를 update할 때, 앞으로 $k$ step에 대해서(도식에서는 $k=3$) discriminator의 loss가 어떻게 반응하는지를 계산하고 일반적인 backpropagation과 같이 이에 대한 gradient들을 계산하여 반영해주겠다는 전략입니다.
단, 실제로 discriminator가 $k$번 update되지는 않는다는 점(붉은 화살표는 하나뿐)에서 기존의 GAN과는 다르다고 할 수 있습니다.
* 사실 처음 논문을 읽었을 때 쉽게 드는 생각이 "기존의 GAN에서 G를 고정하고 D를 k번 update해준 다음 다시 G를 update하는 방식과 다른 것이 뭐지?"라는 것이죠. 이 부분이 아직은 좀 두루뭉술하겠지만 차차 수식과 함께 설명하며 고민해보면 다르다는 것을 아실 수 있을 것이라 생각됩니다.
Unrolling GANs
수식과 함께 좀 더 깊이 살펴보도록 하겠습니다. 기존의 GAN 문제가 다음과 같은 $f$ 함수를 value로 하여 parameter를 학습하는 방식을 사용할 때:
$$f(\theta_G,\theta_D) = \mathbb{E}_{x\sim p_{data}~(x)}[log D(x;\theta_D)] + \mathbb{E}_{z\sim p_x(z)}[log(1-D(G(z;\theta_G);\theta_D))],$$
unrolled GAN에서는 $f_K(\theta_G, \theta_D) $라는 surrogate objective function을 도입하여 이상적인 objective function인 $f(\theta_G,\theta_D^*(\theta_G))$를 좀 더 근사할 수 있도록 합니다. 여기서
$$f_K(\theta_G, \theta_D) = f(\theta_G,\theta_D^K(\theta_G,\theta_D))$$
이며, learning rate $\eta^k$에 대해 $\theta_D^*$의 local optimum을 구하는 것은
$$\begin{align} \theta_D^0 &= \theta_D \\ \theta_D^{k+1} &= \theta_D^k+\eta^k\frac{df(\theta_G,\theta_D^k)}{d\theta_D^k} \\
\theta_D^* &= \lim_{k\rightarrow \infty}\theta_D^k\end{align}$$
와 같이 나타낼 수 있습니다. 즉, $k\rightarrow \infty$가 되면 optimal discriminator를 얻을 수 있으므로 true generator objective function $f(\theta_G,\theta_D^*(G))$에 대해 푸는 문제가 되고, $k= 0$이면 기존의 GAN objective와 정확히 일치하게 됩니다.
따라서 unrolled GAN은 standard GAN의 문제점과 true generator loss를 구하기 위해 필요한 computational cost 사이에서 저울질(trade-off) 하며 좀 더 나은 결과를 얻고자 하는 것입니다.
이제 이 surrogate loss를 이용하여 generator와 discriminator의 parameter들을 update해보겠습니다.
$$\begin{align} \theta_G &\leftarrow \theta_G - \eta\frac{df_K(\theta_G,\theta_D)}{d\theta_G} \\ \theta_D &\leftarrow \theta_D + \eta\frac{df(\theta_G,\theta_D)}{d\theta_D}. \end{align}$$
일단 이 정도면 unrolled GAN에서 큰 맥락은 대부분 살펴 본 것 같습니다. 다음 편에 이어서 이론적 부분에서 분석한 unrolled GAN 장점을 좀 더 살펴보고 실험 결과와 Appendix의 내용들까지 정리해보도록 하겠습니다.
* 그나저나 아무래도 좀 더 자세히 알려면 unrolled GAN도 코드를 보고 따라 짜보기라도 해야겠네요...가능하다면 앞서 소개한 내용들도 tensorflow 등으로 코드를 짜고 해설과 함께 올려보도록 하겠습니다.
** 항상 그렇듯이 오탈자나 잘못된 점 궁금한 점이 있으시면 블로그의 댓글로 달아주세요. 가능한 빠른 시간 안에 확인하고 수정 혹은 답변을 하겠습니다. 이런 댓글과 함께 google plus 혹은 facebook을 통한 공유나 좋아요를 눌러주시는 것이 매우 큰 보람이 됩니다.
다음 읽을거리
- 초짜 대학원생의 입장에서 이해하는 Unrolled Generative Adversarial Networks (2)
- 초짜 대학원생 입장에서 이해하는 Generative Adversarial Nets (1)
- 초짜 대학원생 입장에서 이해하는 Generative Adversarial Nets (2)
참고문헌:
[1] Unrolled Generative Adversarial Networks - L. Metz et al. 2016
[2] Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks - Alec Radford et al. 2016
[3] Generative Adversarial Nets - Ian Goodfellow et al. 2014 논문,
[4] NIPS 2016 Tutorial: Generative Adversarial Networks
안녕하세요. 아까 네이버 미팅에서 뵈었었는데 논문 번역해서 공유하신다길래 혹시나 했는데 맞네요ㅎㅎ Unrolled GAN 포스팅 잘 보았습니다 !
답글삭제안녕하세요 반갑습니다ㅎㅎ 댓글 감사해요 보람이 있군요 ㅎ
삭제가볍게 전체를 흝고 있습니다.
답글삭제Unrolled GAN을 보고 바로 떠오르는것은 바둑의 수읽기 같습니다.
세 수 정도 먼저 그려보고 한 수를 두기. :-)
맞습니다. 그렇게 설명하는게 쉬운 예시네요. 저도 다음부터 이 예시를 애용해야겠습니다.
삭제