2017년 6월 18일 일요일

초짜 대학원생의 입장에서 이해하는 f-GAN (1)

* f-GAN: Training Generative Neural Samplers using Variational Divergence Minimization, Sebastian Nowozin et al. 2016을 바탕으로 한 리뷰

오랜만에 다시 GAN으로 돌아와서 포스팅을 해봅니다. 요즘은 너무 GAN쪽 연구가 많이 그리고 빨리 발전하고 있어서 글을 리뷰하는 것이 민망할 지경입니다. 심지어 응용뿐 아니라 이론적인 부분에서도 꽤나 많은 부분이 연구되고 바뀌고 있기에 따라가기가 벅찹니다. 

이 논문만 해도 작년 6월 arXiv에 올라왔고 NIPS 2016 워크샵에서 발표하는 것을 직접 들었었는데 이게 한~~~~~~참 전 논문으로 취급되고 있는 것을 보면 머신러닝 분야는 참 짧은 시간에 격세지감 느끼기 좋은 분야인 것 같습니다...orz

WGAN 이전에는 f-GAN 논문이 다른 연구들을 모조리 잡아먹는? 느낌을 주었으나 이제는 이 말도 무색할 지경입니다. 그래도 f-GAN 논문은 여러모로 GAN 분야에서 매우 중요한 milestone이기 때문에 사실 다른 것보다 더 먼저 리뷰했어야 하는데 여러 이유로 블로그 글로 정리하지 못해서 항상 마음에 걸렸습니다.

아무튼 이 논문이 마치 어제 나온 것인양! 한 문장으로 요약해보자면 다음과 같습니다:

"앞으로 너희가 어떤 divergence를 사용해서 GAN을 구성하든지 모두 다 나의 하위 호환(special case)일 따름이다."

...패기롭다....ㅋㅋ

이 때만 해도 vanilla GAN objective가 JSD를 이용해서 $p_{model}$과 $p_{data}$의 차이를 계산하는 것이란 분석을 바탕으로, 조금씩 네트워크의 구조를 바꾸거나 regularization term을 새로 붙여본다거나 하는 등 약간의 변화만 주고 있던 때라는 점을 생각해보시면...

이렇게 근간을 쥐고 흔드는 내용을 NIPS 2016 워크샵에서 얘기해줄 때 이제 막 혹시 다른 divergence는 어떨까 하면서 새로운 GAN 형태를 연구해보던 사람들은 어떤 느낌이었을까요? 저는 그 때서야 갓 GAN을 접하던 시기(사실 처음 그곳에서 GAN이란 게 있단 것을 알았습니다ㅎㅎ)라 잘 못 느꼈습니다만 기존에 연구를 하던 사람들에게는 꽤나 충격적이었을 것으로 보입니다.

아무튼 오늘은 이 f-GAN에 대해 차근차근 알아보도록 하겠습니다.


Contribution


이 논문이 주장하는 contribution은 다음과 같이 세 가지 정도로 볼 수 있습니다:
  1. GAN objective를 임의의 f-divergence에 대해 일반화하였다. 
  2. 이런 일반화된 GAN objective를 풀 수 있는 GAN algorithm을 제시하되 vanilla GAN algorithm 보다 단순화 하였고 이에 대한 local convergence를 증명하였다. 
  3. 여러 divergence에 대해 결과를 확인하고 보여주었다.
여기서 1번에 f-divergence라는 것은 어떤 조건을 만족하는 임의의 함수 f에 대해 divergence들을 무수히 만들 수 있는데, 이 모든 divergence를 껴넣을 수 있는 일반적인 GAN objective를 도출했다는 뜻입니다. 

이게 바로 제가 위에서 애기했던 "하위 호환"이란 부분에 해당하는 내용입니다. 즉, JSD를 사용하는 vanilla GAN도 어떤 특정한 f에 대해 만들어진 divergence를 사용한 f-GAN의 special case라는 얘기죠. 

게다가 이런 일반적인 objective를 푸는 알고리즘에 대해 local convergence긴 하지만 convergence도 증명을 해서 보여줬습니다. 결국 앞으로 divergence 계열로 누군가 objective를 조금 바꿔서 연구를 한들 결국 "너넨 내 연구를 reproduce하는 정도가 될 뿐이야"란 것이죠...(거참 너무하네)

사실 위와 같이 이들이 주장하는 contribution 외에도, 저는 자자들이 GAN을 바라보는 또 다른 관점을 볼 수 있어서 매우 흥미로웠는데요 이에 대해 간단히 소개하며 본격적으로 내용을 살펴보도록 하겠습니다. 

Learning Probabilistic Models


"확률 모델을 학습한다"는 것에 대해 잠시 생각해보겠습니다.


모델을 추정한다는 것은 일반적으로 $\cal{P}$라는 진짜 distribution이 저기 어딘가에 있을 때, 이로부터 얻은 sample들이 i.i.d.하게 관측되었다고 가정한 후 $\cal{Q}$라는 parametric 모델 class를 만든 다음 그 안에서 $\cal{P}$와 가장 "가까운" 모델의 parameter를 찾아내는 것을 말합니다.

그러려면 먼저 $\cal{P}$와 $\cal{Q}$의 "차이" 혹은 "거리"를 계산해야겠죠. 두 분포 사이의 거리를 줄여나가는 방향으로 모델 parameter를 조정하는 과정을 모델 fitting이라고 할 수 있습니다.

이렇게 하기 위해 보통 아래와 같은 $\cal{Q}$에 대한 가정들을 사용하곤 하는데요:
  • tractable sampling
  • tractable parameter gradient with respect to sample
  • tractable likelihood function
이 중 가장 강력한 가정이 바로 likelihood function이 tractable하다는 것입니다. 이 부분에 대해서는 VAE에 대한 글에서 열심히 이야기 했었죠. 

한편 GAN-type 모델들은 random variable input을 사용하여 non-linear transformation을 통과시키면 output으로 sample이 튀어나오는 구조입니다. 마치 버튼을 누르면 sample이 튀어나오는 sampler처럼 생각하실 수 있습니다. 

GAN-type 모델들의 독특한 점은 바로 이 부분에서 나옵니다. GAN-type 모델들은 다른 확률 모델들과는 달리 likelihood를 근사하려하지 않고 단지 sample을 뽑아주는 sampler이기 때문에 likelihood-free 모델이라고 할 수 있기 때문입니다. 

Three categories to measure the difference


아무튼 이렇게 어떤 분포 간의 차이를 계산하는 방법은 크게 세 가지로 나눌 수 있는데요 먼저 Integral Probability Metrics (IPM)이라는 방식입니다:
$$\gamma_{\cal{F}}(Q,P)=\sup_{f\in\cal{F}}\left| \int f dQ-\int f dP\right|.$$
여기서 주목할 것은 차이를 계산할 때 각 분포의 expectation 값의 차이를 계산한다는 것입니다. 함수 class인 $\cal{F}$를 어떻게 설정하냐에 따라서 성격이 바뀌는데 예를 들어 Reproducing Kernel Hilbert Space (RKHS)라면 Kernel MMD라는 이름을 갖고 또 다른 형태의 함수 공간을 정의하면 Wasserstein distance이 되기도 합니다. (네 맞아요! WGAN의 Wasserstein distance가 바로 여기에 속합니다.)

두 번째 방식은 Proper scoring rules입니다. 이 용어가 매우 생소하실 수 있겠지만 아마도 maximum likelihood는 매우 익숙하실겁니다. Maximum likelihood 방식이 바로 여기에 해당하는 방식입니다: 
$$S(Q,P)=\int S(Q,x)dP(x).$$
Proper scoring rules는 loss function이라 생각하시면 되는데요 잘 보시면 적분 식 안 쪽에 우리가 분포를 알고 있는 모델 $Q$의 $P$에 대한 fitting score의 expectation을 계산하고 있습니다. 이 score는 $P$와 $Q$가 정확히 일치할 때 가장 최소값을 갖습니다.  

이 역시도 우리가 모델 $Q$를 여러가지 density function으로 바꿔가면서 값을 체크할 수 있고 $P$에 대해서는 expectation을 계산하는 형태라는 것을 기억하시면 됩니다. 

세 번째 종류는 f-divergence를 사용하는 방식입니다:
$$D_f(Q||P) = \int p(x) f(\frac{q(x)}{p(x)})dx.$$
이전과는 다르게 두 분포의 ratio를 계산하여 차이를 구하는 것을 보실 수 있습니다. 적분 식 안에 있는 함수 f 특정 조건을 만족하는 녀석으로 바꾸면 다양한 divergence가 만들어집니다. 

그런데 이 종류의 방식에는 매우 큰 문제가 하나 있는데요 사실 우리가 아는 것은 $Q$의 분포이지 우리가 찾고 싶은 진짜 $P$ 분포는 전혀 알지 못한다는 점입니다. 이전까지는 expectation을 사용했기 때문에 분포를 알지 못해도 expectation은 sample들로부터 계산을 할 수 있었지만 이젠 아니죠. 

이렇게 보면 f-divergence는 사용하는 것이 전혀 불가능한 것처럼 보입니다. 그런데 여기서 GAN을 보는 또 다른 새로운 관점이 나옵니다. GAN이 이런 어려운 방식의 문제(distribution-distribution)를 쉬운 형태의 문제(expectation-expectation)로 바꿔주는 일종의 converter 역할을 해준다는 것입니다. 



이에 대한 부분은 다음 부분을 설명하며 자연스럽게 이해가 되실 것으로 보입니다. 일단은 GAN이 이런 역할을 해준다는 것을 염두에 두고 읽으시면 될 듯 하네요.

f-divergence


먼저 오늘의 주인공인 f-divergence에 대해 좀 더 알아보겠습니다. 이미 언급했지만 적분 안에 f 함수가 특정 조건을 만족할 때 divergence가 된다고 했었죠.

그 조건이란 것이 생각보다 간단합니다. f-divergence를 다시 적어보면 $$D_f(Q||P) = \int p(x) f(\frac{q(x)}{p(x)})dx$$인데 여기서 f함수는 generator function이라 불립니다. 이 때 f는 convex 함수로 $f(1)=0$를 만족해야 합니다. 상당히 직관적인데요. $f(1)=0$가 되려면 $\cal{Q}$와 $\cal{P}$가 같아져야하고 그 때 divergence 값이 0이 나오는 것을 확인하실 수 있습니다. Divergence가 차이를 계산하는 것이라 생각하시면 당연히 만족해야하는 조건입니다.

이 조건만 만족하면 되기 때문에 아주 단순한 형태의 generator function만 갈아끼면 매우 다양한 divergence들을 만들 수 있습니다. 잘 보시면 GAN도 특정 generator function으로 나타낼 수 있다는 것을 미리 보여주고 있습니다:

Estimating f-divergence


하지만 앞서 얘기했다시피 알지 못하는 distribution에는 계산할 수 없으면 아무 소용이 없죠. 다행히도 f-divergence를 써먹기 위해 근사하는 방법에 대해 이미 연구가 되어 있었습니다(Nguyen et al. 2010). 다만 이 연구는 divergence를 근사하는 방법에만 집중을 하고 있고 이를 이용해서 모델을 근사하는 것에는 사용하지 않았던 것을 f-GAN에서 저자들이 잘 가져와 사용한 것입니다 (역시 사람은 두루두루 많이 알고 똑똑해야...). 

이렇게 f-divergence를 estimate하는 부분이 GAN과 직접적으로 연결되는 다리이기 때문에 매우 중요하기 때문에 조금 더 살펴보겠습니다.

먼저 f-divergence의 조건이었던 f는 convex 함수라는 것을 이용합니다. 모든 convex 함수는 다음과 같은 성질을 만족하는데요:
$$f(u) = \sup_{t\in dom_{f^*}}\{tu-f^*(t)\}$$
여기서 $f^*$를 함수에 대한 Fenchel conjugate이라고 부릅니다. 서로 dual이기 때문에 $f^{**}=f$도 성립합니다. 오 이 또 무슨 이상한 수식이냐 하시겠지만 이 녀석이 무슨 말을 하는 것인지 약간의 아이디어를 드리자면 다음과 같습니다: 
"모든 convex 함수 f는 선형(linear) 함수들의 point-wise max로 표현하는 것이 가능하다. "
이 문장도 이해가 안 되신다면 다음과 같은 그림으로 이해해볼 수 있습니다. 


어느 정도 감이 오시나요? 그럼 이제 이 성질을 이용해서 다음과 같이 f-divergence 수식을 풀어나갈 수 있습니다:


첫번째 등호의 식은 단순히 $f(\cdot)$을 정의대로 갈아서 껴준 것에 불과합니다. 다음에 부등호로 넘어가는 것 역시도 그리 어렵지 않습니다. Supremum과 integration의 순서를 바꿀 때 생기는 잘 알려진 부등호로 Jensen's inequality 때문입니다. 약간만 단순화해서 생각해보시면 maximum들의 합이 합의 maximum보다는 당연히 크죠. 이 외에도 arbitrary function class인 $\cal{T}: X \rightarrow \mathbb{R}$가 가능한 모든 함수가 아닌 일부 부분 집합만을 가지고 있을 수도 있습니다. 

어? 그런데 이렇게 lower bound를 만드는 방식 어디서 많이 보셨죠? 네 그동안 자주 보던 variational representation입니다. 이렇게 하한을 만든 다음 이 하한을 maximize하는 방식으로 근사하면 되겠네요. 그리고 이 bound는 다음 식을 만족하기만 하면 tight하다는 것이 밝혀져있습니다:
$$T^*(x) = f'(\frac{p(x)}{q(x)})$$
이 조건이 이후 f 함수와 함수 class $\cal{T}$를 결정하는데 일종의 가이드 역할을 할 수가 있습니다. 예를 들자면, 유명한 reverse Kullback-Leibler divergence의 경우에는$f(u) = − log(u)$를 generator 함수로 갖고 이에 따라 $T^*(x) = −q(x)/p(x)$를 만족하는 녀석들이 함수 class로 정해집니다. 

그런데 이렇게 supremum과 integration의 순서를 굳이 바꿔서 lower bound를 만드는 이유가 있습니다. 이제 적분식이 안으로 들어갔기 때문에...
짠! 세번째 등호와 같이 안쪽이 expectation 형태가 됩니다!!! 어?!! 그러면 이제 더이상 우리가 분포를 몰라도 sample을 가지고 계산을 하는 것이 가능해집니다!! (저는 이런거 볼 때마다 소름이...)

그런데 이렇게 하고 보니까 저 형태 어디서 많이 보던 형태 아닌가요? GAN이랑 매우 닮았죠? (소오오오름!! 저만 그런건가...-_-;;) 분명 우리는 f-divergence를 estimate하기 위해 문제를 풀기 시작했지만 결국엔 GAN의 형태에 도달하는 새로운 방법을 발견하게 된 것이죠. 즉, vanilla GAN을 하나의 special case로 갖는 generalized GAN objective를 제시한 것입니다. Generator 함수 f를 잘 맞춰주면 실제로도 vanilla GAN objective를 만드는 것이 가능합니다. 


그렇기 때문에 앞서 말한 것과 같이 GAN이 어려운 문제(distribution-distribution)를 쉬운 문제(expectation-expectation)로 바꿔서 풀어주는 일종의 convertor 역할을 하고 있었다고 해석할 수 있게 됩니다. 

이 외에도 다양한 common divergence들에 대해서 해당하는 함수들을 마치 look-up table처럼 갈아끼기 좋게 만들어서 appendix(table 5, 6)에 넣어주었습니다. 허허...즉 이제 동작하는 아무 GAN implementation 코드를 받은 다음 거의 메뉴얼 북처럼 보고 두 줄 혹은 세 줄 정도만 바꿔주면 새로운? GAN을 만들 수 있게 된 것입니다. 

이제 어떤 정도의 파괴력을 지닌 논문인지 약간은 감이 잡히시나요. 저는 당시에 아직 뭐도 잘 모르면서도 상당히 신기해서 아마 나중에 동영상을 찾아보시면 아시겠지만 후반부 질문자 중 제 목소리를 들으실 수 있을 겁니다ㅋㅋ

이제 이렇게 일반적인 GAN objective를 제안했으니 이를 풀 알고리즘과 convergence를 증명하는 부분이 남았는데요 이에 대해서는 다음 편에 이어 얘기해보도록 하겠습니다. 

다음 읽을거리



참고문헌:



댓글 4개:

  1. 안녕하세요! 글 너무 잘 봤습니다.
    궁금한 점이 있어서 댓글로 남겨드립니다.
    저기서 divergence가 저런식으로 conjugate function을 써서 나오는건 알겠는데, lower bound가 왜 갑자기 T(x)라는게 도입되면서 형성되는지 갈피를 잡을 수가 없는데 조금만 더 부연설명 부탁드립니다..

    답글삭제
    답글
    1. HS Choi님 댓글 감사합니다. 일단 T(x)가 아니더라도 Jensen's inequality에 의해 저렇게 부등호가 되는 것은 알 수 있습니다. 둘째로 T(x)가 나오기 전에 식을 보시면, supremum 안의 x는 constant이고 즉 x에 대해 pointwise supremum이라는 것을 알 수 있습니다. sup이 밖으로 빠져나오기 위해서는 어찌 되었든 아에 pointwise sup 값을 나타내줄 어떤 x에 대한 함수가 필요합니다. 그것을 T로 갈아 끼어넣었는데 사실 모든 x에 대해 pointwise sup을 정확히 표현해줄 함수 T를 만들어주는게 어렵죠 이에 따라서 자연스럽게 lower bound가 됩니다.

      삭제
  2. 안녕하세요! 글을 읽던 중에, IPM이 expectation의 차이라고 해주셨는데, 식에서 integral f dP 이런식으로 되어있는데 보통 expectation을 계산할때 확률분포를 곱해줘서 적분을 하는데 저기서는 왜 확률분포를 곱하지 않는지 알 수 있을까요?..

    답글삭제
    답글
    1. 안녕하세요 dP = p(x)dx로 정의입니다. 측도론에서 하는 얘기입니다 :) https://www.quora.com/Is-dp-x-p-x-dx-in-measure-probability-theory

      삭제