2018년 5월 9일 수요일

[Paper Skim] Spectral Normalization for Generative Adversarial Networks

Spectral Normalization for Generative Adversarial Networks

TL;DR: A novel weight normalization technique called spectral normalization to stabilize the training of the discriminator of GANs.
Keywords: Generative Adversarial Networks, Deep Generative Models, Unsupervised Learning
Accept: (Oral)
Rating: 8-8-8
Review: https://openreview.net/forum?id=B1QRgziT-

1. Introduction


Preferred network 그룹에서 나온 논문. (최근 핫한 일본 그룹) 그리고 Ian Goodfellow의 홍보 (보증?...) 개인적으로 매우 취향인 논문. (이후 더 자세히 리뷰 예정) 


GANs를 안정적으로 학습시키는 것을 새로운 weight normalization으로 해결해보고자 함. Spectral normalization이라 불리는 이 방법은,
  • Intensive hyper parameter이 필요없음. Lipshitz constant가 유일한 hyperparameter to be tuned. (심지어는 tuning 안 해도 잘 됨)
  • Implementation이 단순하고 computational cost가 적음. 
Batch normalization이나 weight decay, feature matching on the discriminator와 같은 regularization tech.가 없이도 working 잘 함.

2. Spectral Normalization


각 레이어의 spectral norm을 제약함으로써 Discriminator function $f$의 Lipschitz constant를 컨트롤 함. 

ReLU와 같은 activation function의 Lipschitz norm은 1이기 때문에 네트워크 전체를 볼 때 고려하지 않아도 되고, 결국 Weight의 Lipschitz norm을 나눠줌으로써 각 weight matrix $W$의 Lipschitz constant $\sigma(W)=1$:
$$\bar{W}_{SN}(W):=W/\sigma(W).$$

이를 바탕으로 $||f||_{Lip}$가 1로 상계를 갖도록(upper bounded) 함.

Gradient Analysis of the Spectrally Normalized Weights


The gradient of $\bar{W}_{SN}(W)$ w.r.t. $W_{ij}$:
\begin{align} \frac{\partial\bar{W}_{SN}(W)}{\partial W_{ij}}  &= \frac{1}{\sigma(W)}E_{ij} - \frac{1}{\sigma(W)^2}\frac{\partial \sigma(W)}{\partial W_{ij}}W \\&= \frac{1}{\sigma(W)}E_{ij} - \frac{[u_1v_1^T]_{ij}}{\sigma(W)^2}W \\&= \frac{1}{\sigma(W)} (E_{ij} - [u_1v_1^T]_{ij}\bar{W}_{SN}) \end{align}
여기서 $E_{ij}$는 $(i,j)$-th entry는 1 나머지는 0인 행렬이고 $u_1$과 $v_1$이 first left and right singular vecotrs of $W$. $h$를 hidden layer라고 하면 아래가 성립함:
\begin{align}\frac{\partial V(G,D)}{\partial W}&=\frac{1}{\sigma(W)}(\hat{E}[\delta h^T]-(\hat{E}[\delta^T\bar{W}_{SN}h])u_1v_1^T)\\
&= \frac{1}{\sigma(W)}(\hat{E}[\delta h^T]-\lambda u_1v_1^T) \end{align}
여기서 $\delta:=(\partial V(G,D)/ \partial(\bar{W}_{SN}h))^T, \lambda:=\hat{E}[\delta^T(\bar{W}_{SN}h)]$이고 $\hat{E}[\cdot]$은 각 미니 배치의 empirical expectiation을 나타냄.
For some $k\in \mathbb{R}$, $\hat{E}[\delta h^T]=ku_1v_1^T$일 때 $\frac{\partial V}{\partial W}=0$이 성립함.

여기서 식 (5)의 해석이 매우 재미있는데, 식의 첫번째 항은 normalize되지 않은 weights에 대한 미분이므로 별다를 것이 없고 두번째 항이 추가된 것으로 생각해보면, 이를 adaptive regularization coefficient $\lambda$만큼 첫번째 singular component를 penalize하는 regularization 항으로 본다면 다음과 같은 해석이 가능함:

$\lambda$가 양수라는 얘기는 $\delta$와 $\bar{W}_{SN}h$가 비슷한 방향을 가르키고 있다는 것을 의미함. 즉, $W$의 column space가 한 쪽 방향으로만 집중해서 update되는 것을 막아준다고 해석할 수 있음. 논문에서는 이를 통해 spectral normalization이 네트워크의 각 layer가 한 방향으로만 sensitive하지 않도록 막는다고 얘기함.

3. Spectral Normalization vs Other Regularization Techniques


Weight normalization은 결과적으로 너무 강한 constraint를 걸어버리는 경향이 있음. Weight normalization은 weight matrix의 rank를 1이 되도록 강제함 (matrix norm과 weight normalization definition에 의해 수식을 보면 확인할 수 있음). 

그런데 이렇게 하면 discriminator가 하나의 feature만을 보고 probability distribution을 구별해야하기 때문에 discriminator가 매우 sensitive하고 unstable하게 만드는 경향이 있음.

Orthonormal regularization on each weight는 spectral normalization과 유사하면서도 학습을 안정화해주기는 하지만,
$$||W^TW-I||_F^2$$
weights를 orthonormal하게 하므로써 (모든 singular value를 1로 강제하기 때문에) spectrum의 envelop을 망치고 중요한 정보를 잃어버리는 경향이 있음. Spectral normalization은 spectrum의 scale만을 조절하기 때문에 (최대 값을 1) 이와는 다름.

GP와 같은 경우는 위에서 설명한 다른 normalization tech.들과 같은 문제는 없지만 현재 generative distribution의 support에 매우 강하게 엮여있다는 약점이 있음. 이 때문에 학습이 진행됨에 따라 generative distribution의 support도 바뀌기 때문에 학습 과정이 불안정적이 된다는 단점이 생김. Spectral normalization은 학습하는 함수를 operator space에서 regularize하기 때문에 들어오는 데이터 batch에 보다 덜 민감한 것을 볼 수 있음.


4. Experiments


최초로 단일 네트워크로 이미지넷 1000개 범주의 이미지를 생성한 방법인 것만으로도 큰 의미를 지님.







댓글 3개:

  1. 안녕하세요? 좋은 글 감사합니다. 질문이 있어 글을 남깁니다. "Weight normalization은 weight matrix의 rank를 1이 되도록 강제함"이라고 설명하셨는데, weight normalization을 적용할 때 Frobenius norm을 적용할 때와 column/row vector의 norm을 이용해 weight normalization을 할 수도 있는데 모두 동일하게 rank를 1로 강제하는지, 경우에 따라 다른 건지 궁금합니다 (rank 제약에 대해 조금 더 친절한 설명 주시면 감사하겠습니다).

    답글삭제
    답글
    1. Equation 3에서 닫힌 괄호가 누락되었습니다.

      삭제
  2. 방법론 자체는 그냥 weight의 표준편차로 나눠주기만 하면 되는 건가요?

    답글삭제