2017년 1월 19일 목요일

초짜 대학원생의 입장에서 이해하는 Domain-Adversarial Training of Neural Networks (DANN) (1)

Domain-Adversarial Training of Neural Networks -  Yaroslav Ganin 2016 논문을 기본으로 작성한 리뷰. 

Domain Adaptation에 GAN을 적용한 이 논문은 사실 GAN의 original paper보다 먼저 봤는데 어쩌다 보니 순서는 거꾸로 설명을 하게 되었습니다 (...아니 사실 내가 본 순서가 이상하지..-_-)
Domain Adaptation도 GAN에 비견할만큼 주목받고 있는데 이 논문은 심지어 이 분야에 GAN을 적용해 합쳐버린 논문입니다!

이 주제는 전이 학습(transfer learning)이라는 또 다른 역사가 깊고 활발히 연구되고 있는 분야와 직접적인 관계가 있기 때문에 여러 모로 알아두면 좋은 논문.....이라고 해서 저자를 보면 역시나 네임드 Hugo Larochelle이 떡하니 박혀있네요 (무슨 만화 원피스 삼대장 느낌이다..).

본격적으로 논문 내용을 다루기 전에 기본적으로 알면 이해에 도움이 될 배경지식을 잠시 소개하고 가겠습니다.



[Domain Adaptation에 대한 개념적 소개]

Domain Adaptation 역시도 이름을 잘 뜯어보면 뭘 하려는 것인지 대충 감이 옵니다. 어떤 두 개의 domain들이 있으면 한 쪽을 다른 쪽으로 조정하거나 맞추려는(adapt) 것인데요

그럼 자연스럽게 "왜 이런 일을 해주려는 것인데?" 라는 의문이 생깁니다.

답을 하기 위해서 잠시 전이 학습(transfer learning)에 대한 소개가 먼저 필요할 것 같습니다. 학습을 위해서는 가장 기본적으로 필요한 것이 교본(training data, or labeled data)입니다. 그러나 보통 이런 교본은 매우 값이 비싸지요. 때에 따라서는 현실적으로 training data를 만드는 것이 불가능한 경우까지도 있습니다. 게다가 우리가 알고 있는 학습 기법들은 많은 경우 이미 학습된 혹은 training data domain이 test data domain과 비슷한 경우에 효율적으로 동작합니다.

따라서 이런 여러 상황을 고려할 때, 가장 먼저 생각해볼 수 있는 것이 이미 알고 있는 지식을 이용해서 새로운 상황을 학습하는데 사용해보는 것(knowledge transfer)입니다.

벌써 어느 정도 감이 오실텐데요 이런 아이디어를 실제로 적용해본 것이 domain adaptation입니다. 우리가 이미 알고 있는 training data domain의 분포(source knowledge)로부터 시작해서 새로운 test data domain 분포(target knowledge)와 비슷하게 조정해나가면 아무래도 아무 것도 없는 상황에서 새로 학습하는 것보다 좀 더 낫겠죠.

이런 개념들은 이미 1995년 NIPS workshop에서 "Learning to learn"이라는 주제로 소개가 되었을만큼 아주 오랜 역사를 가지고 있습니다. 이해를 돕기 위해 실제로 여러 분야에서 유용하게 연구된 예시들을 보면 다음과 같습니다.

  1. 가장 직관적인 예시로, 영화에 대한 리뷰가 긍정적인지 부정적인지를 판별하는 것을 학습시킨 모델에서 책에 대한 리뷰를 판별하도록 하는 sentiment classification이 있겠습니다. 꼭 영화와 책이 아니더라도 다양한 상품에 대해 개별적으로 학습을 시키려면 대량의 교본(labeled data)가 필요할 것이나 transfer learning을 이용한다면 좀 더 효율적으로 학습을 할 수 있습니다.
  2. 또한 데이터가 자주 갱신되어 시간이 지나면 특정 시간대에 수집한 데이터로 학습한 모델이 다른 시간대에 수집한 데이터에 잘 적용이 안되는 경우에도 비슷한 논리로 적용을 해볼 수 있겠네요.

이렇게 training과 test 분포들이 서로 약간 다른(presence of shift) 환경에서 효율적으로 학습을 하려는 것이 domain adaptation (DA) 입니다. 이 것이 잘 된다면 아주 큰 반향이 있을 수 있는 것이 비지도 학습(unsupervised learning)에 상당한 영향을 줄 수 있기 때문입니다. 세상에는 답이 있는 경우보다 답이 없는 경우가 훨씬 더 많으니까요.

Transfer learning에 대해 좀 더 자세한 survey를 원하는 분들에게는 A Survey on Transfer Learning - SJ Pan et al. 2010 논문을 추천합니다. 세상에나 인용만 2488회라니... 


[Domain-Adversarial Neural Network (DANN) 소개] 


먼저 DANN의 main idea를 요약하자면,

(1) discriminativeness는 유지하면서 (2) domain-invariance 역시도 고려하자

...는 것입니다. 이게 뭔소리냐...하실텐데 일단 그림을 보시죠

shallow DANN 개요도
ppt 그림 그리기...ㅋㅋㅋ

좀 더 풀어 말하자면 우리가 잘하고자 하는 task인 label predictor (classifier)의 역할은 잘하도록 유지하되 우리가 보고 있는 sample의 feature representation이 source domain에서 왔는지 target domain에서 왔는지를 구별 못하게 domain discriminator를 약화하는 방향(gradient reverse @backprop)으로 학습시키겠다는 것입니다.

즉, training data에 대한 classification error는 최소화하되, training data를 classify하기 위해 학습된 feature들이 training data와 test data를 구별 할 수 없도록(두 domain으로부터 나온 data가 feature representation layer를 통과하고 나왔을 때, domain이 구별 안 되게) 하자는 전략입니다. 이렇게 학습된 feature들은 training과 test에 상관없이 robust하겠죠.

사실 이 idea는 Ben-David et al., 2006, 2010 paper에서 소개된 domain adaptation 이론에서 단초를 얻었다고 합니다.
"A good representation for cross-domain transfer is one for which an algorithm cannot learn to identify the domain of origin of the input observation."
"For effective domain transfer to be achieved, predictions must be made based on features that cannot discriminate between the training (source) and test (target) domains."

이 부분에 나오는 알고리즘에 GAN을 사용하면 딱 될 것 같은 느낌이죠  그래서 가져다 사용했다! 는 것이 DANN의 주요 내용 중 하나입니다.

그럼 기존의 DA에서 다른 연구들이 했던 것과는 DANN이 무엇이 다르고 그 것이 어떤 장점이 있는 지에 대해서 간단히 소개하고 본격적으로 이론에 대해 정리해보겠습니다.

보통 이전 연구에서는 fixed feature representation을 찾은 후 이를 사용하여 source와 target domain 간 feature들의 분포를 맞춰주기 위해 source domain의 sample들을 뽑아 reweighing을 하거나 아예 source domain 분포에서 target domain 분포로 mapping해주는 feature space transformation을 찾는 식으로 접근했다면,

DANN 역시도 feature space distribution을 match해주지만 reweighing이나 기하적 transformation을 찾는 것이 아닌 1) feature representation 자체를 바꾸는 방식이며, 2) domain adaptation과 deep feature learning을 하나의 학습 과정 안에서 해결한다는 것이 다릅니다. 또한 3) 분포 간 차이를 측정할 때 deep discriminatively-trained classifier를 사용하는 것도(GAN concept) 매우 다른 점 중 하나입니다.


[Domain Adaptation 이론]

이제 개념은 어느 정도 소개한 것 같으니 좀 더 구체적으로 문제를 정의해보겠습니다.

$X$ : input space, $Y = \{0,\cdots, L-1\}$ : set of $L$ possible labels일 때, $L$개의 class를 분류하는 문제를 생각해보겠습니다. DA 문제는 여기에 더하여 다음과 같이
$$\mathbb{D}_S: source~domain~~~~\mathbb{D}_T: target~domain$$
two different distributions over $X\times Y$이 추가로 있는 경우입니다.
따라서 비지도 DA 학습일 경우 우리는 $S$에 대해서는 label이 존재하지만 $T$에 대해서는 label이 전혀 없습니다.
$$ S = \{ (x_i,y_i) \}_{i=1}^n  \overset{i.i.d.}{\sim} (\mathbb{D}_S)^n;~~T=\{x_i\}_{i=n+1}^N  \overset{i.i.d.}{\sim} (\mathbb{D}_T^X)^{n'}$$
with $N = n+n'$ being the total number of samples.

따라서 이 논문에서 소개하는 알고리즘의 목표는 분류기(classifier) $\eta: X\rightarrow Y$를 만드는 데, $\mathbb{D}_T$의 label에 대한 정보가 전혀 없이도 target risk:
$$R_{\mathbb{D}_T}(\eta)=\Pr_{(x,y)\sim\mathbb{D}_T}\left( \eta(x) \neq y \right),$$ 가 낮도록 학습시키는 것입니다. 즉, test time에서도 잘 되는 classifier를 만들자는 것이죠.

Domain Adaptation 문제를 다루기 위해 먼저 갖춰야할 도구가 있는데 domain divergence 등과 같이 source와 target 분포 간의 distance를 구하는 measure입니다. 보통 DA 문제를 푸는 대다수의 전략이 target error를 source error와 domain divergence의 합으로 upper bound 시키는 방식이기 때문에 그렇습니다.

생각해보면 매우 직관적인 전략이죠. source risk가 target risk의 good indicator가 되려면 일단 두 분포들이 비슷해야할테니 말입니다.

하여 이제 본격적으로 이론에 대해 소개를 해야하는데...잉여력이 부족하기도 하고 너무 길어지는 것 같으니 다음 글에서 이론 소개 이후 간단한 예시 문제(two moon data)를 MATLAB으로 짠 코드와 함께 풀어보겠습니다.

다음 읽을거리

참고문헌:



댓글 6개:

  1. algorithm을 GAN으로 사용하면 딱 될 것 같다고 설명 하셨는데 GAN의 어떠한 특성 때문에 그런 방향의 결론이 나올까요? 궁금합니다!

    답글삭제
    답글
    1. unknown님 안녕하세요. 음 GAN의 개념이 image generation하는 NN과 discriminator NN이 서로 경쟁하게 하자는 것입니다. Generator는 discriminator가 구별 못하는 image를 만들어내는 것이 목표이고 discriminator는 어떻게는 이 image가 real image인지 혹은 generator가 진짜 image처럼 모사한 fake image인지를 구별하려고 합니다. 그럼 여기 DA에서처럼 image sample을 뽑아내는 generator의 분포가 real data의 분포와 동일해지도록 adapt시킬 수 있다면 우리의 목적이 달성되겠지요. http://jaejunyoo.blogspot.com/2017/01/generative-adversarial-nets-1.html의 두번째 그림을 보시고 그 밑에 설명을 보시면 좀 더 이해가 쉬우실 것 같습니다. 답변이 도움이 되었나요?

      삭제
  2. Feature extractor에서 나온 feature들은 Source 와 Target domain의 공통된 부분인 건가요?

    답글삭제
    답글
    1. Unknown님 안녕하세요. 네 맞습니다. 다만 좀 더 자세히 말하자면, 진짜 목적인 classification을 잘 하기 위해 여러 feature들을 사용할 수 있을텐데 그 중에서도 Source와 Target domain을 구별할 필요없는 feature들을 사용하여 classification을 하는 식으로 학습이 되겠죠.

      삭제