2017년 1월 29일 일요일

초짜 대학원생의 입장에서 이해하는 Domain-Adversarial Training of Neural Networks (DANN) (3)

Domain-Adversarial Training of Neural Networks -  Yaroslav Ganin 2016 논문을 기본으로 작성한 리뷰.

 

[DANN으로 toy problem 풀어보기]

* Shallow Neural Network & two moon dataset

공부를 할 때 가장 기본은 많이 읽는 것으로 시작해서 그림으로 생각하고, 마지막으로는 코딩을 통해 직접 이론을 implement를 하여 결과를 시각화하는 것이 최고라고 생각합니다. 거기에 더해 요즘은 블로그로 글 정리를 해보고 있는데 이 것도 스스로 공부가 많이 되네요ㅎㅎ

아무튼 앞에서 이론은 열심히 봤으니까 이젠 코딩을 해봅시다. 지난 글 (1), 글 (2)에서 다뤘던 Domain Adaptation (DA) 이론을 hidden layer 수가 한 개씩으로 얕은 Neural Network를 이용하여 간단한 clustering 문제를 풀어보도록 하겠습니다. 본 글에 나오는 예제 코드는 MATLAB으로 작성하였으며 제 github에 올려두었습니다.

* MATLAB보다 Python이 편하신 분은 제 github에 Window 7,10 Python 3.5 tensorflow-gpu 1.0.0.rc2 버전에서 돌아가는 Jupyter notebook으로 올려두었으니 보시면 됩니다. 
다만 본 글은 수식과 연관하여 이해를 돕기 위해 논문의 pseudo code를 그대로 적용한 MATLAB 코드를 기준으로 설명을 진행하겠습니다. 

사실 이 역시도 논문에 포함되어있는 내용이고, 심지어 저자가 toy code를 python으로 작성하여 github에 올려두었지만 제가 MATLAB language로 옮겨 짜면서 새로 data generation 함수를 추가하였고 main 역시도 더 간단하게 딱 실험에 필요한 부분만 남겨 이해를 도울 수 있도록 요약해보았습니다.

먼저 해결하고자 하는 문제를 확인해보겠습니다.
Clustering 문제에서 많이 보이는 예제 dataset인데, 초승달 한 쌍이 서로 마주 보고 있는 모양이라 two-moon dataset이라 불립니다.
two-moon dataset

위 그림이 two-moon dataset인데요 제 github의 코드 중 'gen_2moon.m' script를 실행하면 만들 수 있습니다. 'o'와 '+'로 두 가지 class가 있는 것을 보실 수 있습니다. 위에 있는 반원이 'o' 아래 있는 반원이 '+'로 label이 붙어있는데 이 역시도 파란 색과 빨간색으로 나뉘어 있습니다. 빨간색target domain으로써 source domain인 파란색 데이터를 약간(35º) 회전하여 만들었습니다. 즉, source domain으로 학습을 하되 label이 없는(우리는 알지만 NN은 label을 모르는) target domain에 대해서도 잘 classify하도록 domain adaptation 해보겠습니다.

DANN은 다음과 같은 구조를 가지고 있습니다: 

shallow DANN 개요도

* 그림 재활용...ㅋㅋㅋ

앞서 글들에서 설명했듯이 Feature Extractor, Classifier, 그리고 Domain Discriminator 크게 세 가지 부분으로 나누어지고, backpropagation을 할 때 adversarial의 형태로 학습을 시켜주기 위하여 gradient reversal layer가 들어가 있는 것을 보실 수 있습니다.

이를 수식으로 나타내면 다음과 같이 source domain에 대한 classification loss와 optional regularizer $R(W,b)$로 이루어진 optimization problem이 됩니다:
$$\min_{W,b,V,c}\left[ \frac{1}{n}\sum_{i=1}^n \mathcal{L}_y^i(W,b,V,c)+\lambda\cdot R(W,b)\right].$$
Note: $\mathcal{L}_y^i(W,b,V,c) = \mathcal{L}_y(G_y(G_f(x_i;W,b);V,c),y_i)$; $i$-th example prediction loss.

다만, DANN은 $R(W,b)$에 약간 독특한 domain regularizer를 사용하여 $\cal{H}$-divergence를 줄이는 방향으로 하기 때문에 통상적인 minimization 문제와는 약간 달라집니다.

차례차례 확인해보겠습니다. $G_f(\cdot)$이 해주는 역할은 sample을 어떤 hidden feature space의 한 점으로 표현해주는 것입니다. 따라서 각각의 domain으로부터 온 sample의 feature representation을
$$S(G_f) = \{G_f(x)|x\in S\},~~~T(G_f) = \{G_f(x)|x\in T\}$$
이라 할 때, empirical $\cal{H}$-divergence는

$\hat{h}_{\mathcal{H}}(S(G_f),T(G_f))$ $$= 2\left(1-\min_{\eta\in\mathcal{H}} \left[\frac{1}{n}\sum_{i=1}^n I \left[ \eta(G_f(x_i))=1\right] + \frac{1}{n'}\sum_{i=n+1}^N I \left[ \eta(G_f(x_i))=0\right]\right]\right). $$
라고 얘기할 수 있습니다. 우리가 원하는 것은 feature가 empirical $\cal{H}$-divergence 값이 줄어드는 방향으로 학습되는 것(domain independent feature)이기 때문에 지난 글 (2)에서도 강조하였었지만 minimization 안의 classifier가 제대로 역할을 못하여 error가 크도록 해야합니다:
$$R(W,b) = \max_{u,z}\left[ -\frac{1}{n}\sum_{i=1}^n\mathcal{L}^i_d(W,b,u,z) - \frac{1}{n'}\sum_{i=n+1}^N\mathcal{L}^i_d(W,b,u,z)  \right],$$
where $\mathcal{L}_d^i(W,b,u,z) = \mathcal{L}_d(G_d(G_f(x_i;W,b);u,z),d_i).$
즉, discriminator의 weight와 bias인 $u,z$에 대해 $\max(-1 \times loss)$라는 것은 $\min(loss)$와 같은 방향이므로 adversarial하게 discriminator를 강화하되 $W,b$에 대한 regularization을 줄 때는 divergence가 minimize하는 방향으로 학습합니다:

$E(W,V,b,c,u,z)$ $$=\frac{1}{n}\sum_{i=1}^n\mathcal{L}^i_y(W,b,V,c)-\lambda\left( \frac{1}{n}\sum_{i=1}^n\mathcal{L}^i_d(W,b,u,z) + \frac{1}{n'}\sum_{i=n+1}^N\mathcal{L}^i_d(W,b,u,z) \right),$$
where we are seeking the parameters $\hat{W}, \hat{V},\hat{b},\hat{c},\hat{u},\hat{z}$ that deliver a saddle point given by

$$\begin{align} &(\hat{W}, \hat{V},\hat{b},\hat{c}) = \underset{W,V,b,c}{arg\min} ~E(W,V,b,c,\hat{u},\hat{z}),
\\ &(\hat{u},\hat{z}) ~~~~~~~~~~~= \underset{\hat{u},\hat{z}}{arg\max} ~E(\hat{W},\hat{V},\hat{b},\hat{c},u,z). \end{align}$$
$(\hat{W}, \hat{V},\hat{b},\hat{c})$에 대해서는 $\min~E(\cdot)$이지만 $(\hat{u},\hat{z})$에 대해서는 $\max~E(\cdot)$입니다. 위에 빨간색으로 강조한 글과 비교해가면서 읽어보시면 됩니다. Minimization과 maximization이 섞여서 상당히 헷갈리실테지만 시간을 두고 곰곰히 생각해보시는 것이 좋습니다. (이 논문을 쓴 친구는 설명을 그리 잘 하는 편은 아닌 듯 합니다. 게다가 저는 이것에 합쳐서 저자의 critical 오타로 인해 한참 시간을 잡아 먹었습니다...orz)

이제 전략이 구체적으로 세워졌으니 논문에 있는 pseudo-code를 크게크게 쪼개어 살펴보며 각 부분에 왜 그런 식의 코드가 들어가 있는지 해설하는 방식으로 진행해보겠습니다.

먼저 통상적으로 'o'와 '+' 두 class를 구별하는 것을 학습하는 부분입니다.
1: Input:
    samples $S =\{(x_i,y_i)\}^n_{i=1}$ and $T=\{x_i\}^{n'}_{i=1}$,
    hidden layer size $D$,
    adaptation parameter $\lambda$.
    learning rate $\mu$,
2: Output: neural network $\{W,V,b,c\}$
3: $W,V \leftarrow$ random_init($D$)
4: $b,c,u,z \leftarrow 0$
5: while stopping criterion is not met do
6:   for $i$ from $1$ to $n$ do
7:      # Forward propagation
8:      $G_f(x_i) \leftarrow$ sigm$(b+Wx_i)$
9:      $G_y(G_f(x_i)) \leftarrow$ softmax$(c+VG_f(x_i))$
10     # Backpropagation
11:    $\Delta_c\leftarrow -(e(y_i)-G_y(G_f(x_i)))$
12:    $\Delta_V\leftarrow \Delta_c G_f(x_i)^T$
13:    $\Delta_b \leftarrow (V^T \Delta_c) \odot G_f(x_i)\odot(1-G_f(x_i))$
14:    $\Delta_W \leftarrow \Delta_b \cdot (x_i)^T$
$E(W,V,b,c,u,z)$에서 첫 번째 부분인 $\frac{1}{n}\sum_{i=1}^n\mathcal{L}^i_y(W,b,V,c)$ 를 풀고 있다고 생각하시면 됩니다. 왼쪽에서 input으로 source domain으로부터 뽑은 sample $x_i$이 들어가면, feature extractor 역할을 하는 $G_f$ layer를 지나 classifier layer $G_y$에서 이 sample의 class $y_i$가 'o'인지 '+'인지 구별하고 error를 계산한 다음 backpropagation을 통해 통상적으로 Weight $(W,V)$와 bias $(b,c)$를 update합니다.

제 backpropagation 글을 보시면 마지막 부분에 학습하는 과정을 수식으로 유도해두었습니다. 본 글에서의 notation을 이 틀에 그대로 적용하면 라인 11~14까지의 수식이 자연스럽게 나옵니다. Sigmoid를 사용한 것과 layer 수까지 같기 때문에 계산을 대입하는 것이 그리 어렵지 않을 것입니다. 그 글에서의 notation과 비교하자면 $e(y_i)$가 $t_i$ $x_j$는 $G_f(x_i)$ 정도인 것을 염두에 두고 보시면 됩니다.

Note: 여기서 $e(y)$는 "one-hot" vector라 해서 y class label 위치에는 1 나머지에는 모두 0인 벡터를 의미하고, $\odot$은 element-wise product입니다.

다음으로 Domain Adaptation을 해주는 부분입니다.
15:     # Domain adaptation regularizer ... 
16:     # ... from current domain
17:    $G_d(G_f(x_i)) \leftarrow$ sigm$(d+u^TG_f(x_i))$
18:     $\Delta_z \leftarrow \lambda(1-G_d(G_f(x_i)))$
19:     $\Delta_u \leftarrow \lambda(1-G_d(G_f(x_i)))G_f(x_i)$
20:     tmp $\leftarrow \lambda (1-G_d(G_f(x_i))) \times u \odot G_f(x_i) \odot (1-G_f(x_i))$
21:     $\Delta_b \leftarrow \Delta_b+$ tmp
22:     $\Delta_W \leftarrow \Delta_W +$ tmp $\cdot (x_i)^T$
23:     # ... from other domain
24:     $j \leftarrow $ uniform_integer$(1,\cdots,n')$
25:     $G_f(x_j) \leftarrow $ sigm$(b+Wx_j)$
26:     $G_d(G_f(x_j)) \leftarrow $ sigm$(d+u^T(G_f(x_j))$
27:     $\Delta_z \leftarrow \Delta_z -\lambda G_d(G_f(x_j))$
28:     $\Delta_u \leftarrow \Delta_u -\lambda G_d(G_f(x_j))G_f(x_j)$
29:     tmp $\leftarrow -\lambda G_d(G_f(x_j)) \times u \odot G_f(x_j) \odot (1-G_f(x_j))$
30:     $\Delta_b \leftarrow \Delta_b +$ tmp
31:     $\Delta_W \leftarrow \Delta_W +$ tmp$\cdot(x_j)^T$
$E(W,V,b,c,u,z)$함수의 regularization part라고 생각하시면 됩니다. 두 번째 부분인 $-\lambda\left( \frac{1}{n}\sum_{i=1}^n\mathcal{L}^i_d(W,b,u,z) + \frac{1}{n'}\sum_{i=n+1}^N\mathcal{L}^i_d(W,b,u,z) \right)$를 풀고 있는 것이죠.

일단 target domain로부터 뽑은 sample $x_j$를 현재의 $G_f$ layer에 통과시켜 feature representation $G_f(x_j)$를 계산하고 source domain sample $x_i$의 feature representation $G_f(x_i)$를 $G_d$를 통해 구별합니다. 라인 18과 27부터 각각의 domain에 대한 discriminator loss를 backpropagate하는 부분이 시작되는데요 domain이 binary classification이고 discriminator의 output은 "domain이 source로부터 왔을 확률"이기 때문에(sigmoid 함수 output) one-hot vector 대신 1과 0이라는 constant 값이 label로 들어가 있죠. 여기서 $\lambda$는 regularization의 trade-off hyperparameter입니다.

앞서 설명한 $E$ 함수에 domain regularization 부분이 -$\lambda$로 시작하기 때문에 라인 11과는 달리 라인 18은 괄호 바깥의 $-$가 상쇄되어 없어진 것을 확인할 수 있습니다. 중요한 점은 여기서 $\Delta$가 붙은 것들은 모두 gradient이기 때문에 함수를 maximization하는 방향임을 염두에 두셔야합니다.
$$\begin{align} &(\hat{W}, \hat{V},\hat{b},\hat{c}) = \underset{W,V,b,c}{arg\min} ~E(W,V,b,c,\hat{u},\hat{z}),
\\ &(\hat{u},\hat{z}) ~~~~~~~~~~~= \underset{\hat{u},\hat{z}}{arg\max} ~E(\hat{W},\hat{V},\hat{b},\hat{c},u,z). \end{align}$$
우리가 푸는 문제가 현재 이와 같기 때문에 gradient 값들을 위와 같이 다 계산해서 합치고 나서 각 변수들을 update해줄 때 방향이 중요합니다. $(W,V,b,c)$들에 대해서는 minimization 문제였으므로 gradient의 반대 방향으로 update해주기 위해 값을 빼주고 반대로 $(u,z)$에 대해서는 maximization 문제기 때문에 값을 그대로 더해줍니다.

32:     # Update neural network parameters
33:     $W \leftarrow W - \mu \Delta_W$
34:     $V \leftarrow V - \mu \Delta_V$
35:     $b \leftarrow b - \mu \Delta_b$
36:     $c \leftarrow c - \mu \Delta_c$
37:     # Update domain classifier
38:     $u \leftarrow u + \mu \Delta_u$
39:     $z \leftarrow z + \mu \Delta_z$
40:   end for
41: end while

드디어 다 살펴 보았습니다. 본 논문의 pseudo-code에서는 제가 여기서 쓴 변수 $z$(예: 라인 39)r가 $d$로 되어있을 것입니다. 앞서 $E$ 함수 등 이론 부분에서는 $z$ notation 사용하기에 일관성을 위해서 $z$로 통일하였습니다. 논문에서는 두 변수를 자꾸 섞어 사용해서 헷갈리더군요.

결과를 비교해보면 다음과 같습니다. 제 github의 코드 중 'sDANN_main.m' script를 바로 실행시키시면 adversarial part가 따로 feedback이 되지 않는 shallow Neural Network (sNN)의 classification 결과가 나옵니다. sNN 대신 shallow DANN 결과를 보고 싶으시면 같은 script의 라인 20의 adversarial_representation = false; 부분을 true로 바꾸시면 됩니다.
sNN과 sDANN target domain 분류 결과 비교

각각에 대해 총 3열의 결과가 나옵니다. 먼저 제일 윗 줄은 whole view로 빨간색이 'o' 파란색이 '+'로 분류된 점들을 보여줍니다. 이미 sDANN이 잘 되었다는 것이 확연히 보입니다. 두 번째 줄과 세 번째 줄의 결과는 각각 'o'의 ground truth와 '+'의 ground truth를 초록색으로 겹쳐서 각각의 class label에 대해 결과가 어찌 나왔는지 보여준 것입니다.

각 초승달의 양 끝 부분에 헷갈릴 수 있는 부분을 sDANN이 훨씬 잘 분류하는 것을 볼 수 있는데요...사실 이 실험은 제가 random seed를 고정해두었습니다. Seed에 따라 결과가 별 차이가 없는 경우도 있습니다. 처음 implement 했을 때는 논문처럼까지는 결과가 잘 나오지 않아 당황했었는데...하필 default seed가 그랬던 것이고 다른 seed에 대해서는 훨씬 결과가 좋은 경우도 많습니다...(나중에 보니 원 저자도 seed를 고정해뒀더군요.)

코드가 읽는 것이 어렵지는 않을테니 (특히나 MATLAB이 편하신 공돌공돌 분들께는 더욱ㅎㅎ) 읽어보시고 여러 방식으로 바꿔보시는 것을 추천드립니다.

이상으로 길었던 DANN 리뷰를 마치겠습니다. 다음에는 뭘 해볼까요? ㅎㅎ

다음 읽을거리

참고문헌:




댓글 8개:

  1. 좋은 글 감사합니다. 항상 잘 읽고 있습니다.
    이런식으로 이해하면 될까요?

    At: domain A 의 training data
    Bt: domain B 의 training data

    1. DANN의 형태는 generator(G) 에 2개의 discriminator(D1, D2)가 달려있는 것
    2. D1은 At의 classification 담당 (G는 A를 속일 수 있을 만큼 real한 data를 generate하게, D는 fake와 real을 잘 구분하게 학습 됨)
    3. D2는 At와 Bt의 domain adaptation을 담당
    3-1. D2는 At와 Bt의 H divergence를 최소화해야 함
    3-2. 이를 위해 D2는 input의 domain을 잘 구분하지 못 하게(At가 들어오면 B로, Bt가 들어오면 A로 classification), G는 At가 들어오면 Bt처럼 보이는 아웃풋을 generate하게 학습 됨.


    이 정리가 맞다면,
    굳이 H-divergence를 쓰는 이유가 뭘까요? D2를 그냥 일반 GAN에서의 discriminator처럼 At 와 Bt의 도메인을 잘 구분할 수 있게 하는 network로 학습시켜도 3-2에서 마지막 결론인 "G는 At가 들어오면 Bt처럼 보이는 아웃풋을 generate하게 학습 됨."을 얻을 수 있지 않나요?

    답글삭제
    답글
    1. 안녕하세요 이준영님 댓글 감사합니다. 음..먼저 2번부터 얘기를 해야할 것 같네요. D1이 해주는 일은 그냥 classification입니다. D1이 하는 일은 fake와 real을 잘 구분하게 학습하는 것이 아니죠. 그리고 generator라 하시면 용어가 좀 오히려 의미를 헷갈리게 할 수 있는데 DANN에서 앞단의 NN은 feature extractor 역할을 합니다. 즉 D2가 없다고 생각하시면 그냥 일반적인 CNN과 다를게 없습니다. D1에서 backprop하면 좀 더 classification을 잘하도록 feature extractor가 업데이트 되겠죠.

      삭제
    2. 3번에서 D2가 해주는 일이 source(real)와 target(fake)를 구별하는 GAN의 역할입니다. 즉 D1을 떼어내면 오히려 GAN 형태가 되겠네요. 그렇기 떄문에 마지막에 말씀하신게 정확히 DANN이 해주고 있는 일입니다. GAN도 divergence를 minimize하는 것으로 식을 풀어낼 수 있어요.

      삭제
  2. 조금 더 추가하면, D2가 domain을 잘 구분하지 못하게 학습돼야 한다고 했는데, A를 B로, B를 A로 classification하게 학습 시키는 것이 정말 구분하지 못하게 학습하는 건지.. 하는 궁금증도 드네요.
    어떻게 보면 일반 GAN은 D가 A,B를 잘 구분하게 하고 G가 D를 속이게 하자(B처럼 보이게 하자)가 컨셉이라면 DANN은 D가 A,B를 반대로 구분하게 하고 G가 D를 속이게 하자(B처럼 보이게 하자)가 컨셉이 아닌가 하는 생각도 들고요. 생각 듣고 싶습니다.

    감사합니다.

    답글삭제
    답글
    1. 앞서와 비슷한 답변을 드리게 되는데요. D1을 떼어내면 GAN과 형태가 같아진다고 볼 수 있습니다. 그리고 DANN의 D도 A,B를 제대로 구분하게 합니다. Gradient reverse가 있기 때문에 G부분에 해당하는 feature extractor가 A,B를 잘 구별하지 못하게 하는(정확히는 A,B domain 차이에 independent한) feature를 찾는 방향으로 update되고 이게 결국엔 G가 D를 속이게 하는 것입니다.

      삭제
    2. 답글 감사합니다 :)

      먼저 D1에 대한 생각이 완전히 틀렸었군요!
      DANN의 main point가 domain independent한 feature를 생성한다는 점이 더 확실히 이해가 됩니다.

      다만, 답변에 DANN의 D2가 A, B를 제대로 구분하게 한다고 언급하셨는데,
      max(E)를 하기 위해서는 R(W,B)를 maximize 해야 하는 거고, 이는
      논문의 식(7) 아래 정의된 -L_d() 함수를 maximize하는 것과 같은데(이는 언급하신 것처럼 L_d()를 minimize하는 것과 같고요), -L_d() 함수가 maximize되기 위해서는 D_2가 A를 B로, B를 A로 "반대로" 구분해야 되는것 아닌가요? maximize니까 gradient reverse가 쓰이는 거고요.

      제가 말씀드리고 싶은건 "D2가 A,B를 제대로 구분하는데 gradient reverse때문에 반대로 되는게 아니라, D2가 A,B를 반대로 구분하는데 이를 위해선 maximize를 해야하고, 때문에 gradient reverse가 쓰일 수 밖에 없다." 이런 정리도 괜찮은가 입니다.

      마지막으로 일반 GAN의 Discriminator와 DANN의 D2와 다른점을 생각해 봤는데, GAN의 D는 G가 A를 B처럼 보이게 유도 하는거라면, DANN의 D2는 G가 A를 B처럼, B를 A처럼 보이게 유도 하는 것인데,, 일반 GAN과 다르게 "B를 A처럼" 부분이 추가 된 거잔아요?
      이 부분이 왜 domain independent한 feature를 만드는지 직관적으로 와닫지 않네요. A를 B처럼 바꾸고도 classification이 잘 되고, B를 A처럼 바꿔도 classification이 잘 되게 유도하는 거니까.. A와 B의 중간점을 찾는걸까요? 이 부분의 재준님의 생각이 궁금합니다.

      쓰다보니 질문이 계속 생겨서.. 너무 두서 없이 써서 죄송합니다.

      항상 정말 감사합니다!

      삭제
  3. 질문이 있습니다.
    regularization을 해줄 때 왜 max(-loss) 형태로 정의했을까요?
    regularization에 divergence를 반영해주고 싶은 것은 이해했는데
    max(-loss) 형태로 왜 했을지가 궁금합니다

    답글삭제
  4. pseudo-code에서 tmp부분이 어떻게 유도가 된 것인지 궁금합니다
    Regularization을 어떤 것에 대해 미분한 것인가요?
    (d와 u 각각에 대한 미분값까지는 이해했는데 tmp부분이 잘 이해가 안가네요ㅠㅠ)

    답글삭제