2017년 7월 2일 일요일

초짜 대학원생 입장에서 이해하는 [CVPR 2017] Learning by Association - A versatile semi-supervised training method for neural networks

* Learning by Association - A versatile semi-supervised training method for neural networks, Philip Haeusser et al. 2017을 바탕으로 한 리뷰

이번 CVPR 2017에 Accept 되었다는 Learning by Association 논문을 리뷰해보겠습니다. 짧은데 재미있는 논문이에요. 

아이디어는 상당히 직관적이고 간단합니다. 구글에 "Learning by Association"이라고 검색하면 가장 먼저 뜨는 것은 사실 인지 심리학 쪽의 자료들입니다. 저자들도 사람이 "Association"을 사용하여 학습을 한다는 내용에서 착안했다고 하는군요. 

이게 무슨 말이냐 하면, 사람은 기계와는 달리 자료 간의 "연관성"을 파악하여 학습할 수 있기 때문에 매우 많은 예제를 보지 않더라도 학습을 잘 할 수 있다는 것입니다. 
Learning by Association 예시 ([2],[3])

즉, 아이가 처음 "개"라는 것을 알게될 때 몇 가지 예시를 보고 나면, 후에 진돗개를 처음 보더라도 연관성을 바탕으로 유추를 하여 "개"라는 것을 자연스럽게 안다는 것이죠.

그렇다면 이런 "연관성"을 기계학습에도 가져와서 사용할 수는 없을까?라는 것이 이 논문의 주된 아이디어입니다. 만약 이런 학습이 가능하다면 label을 얻는 것이 매우 비싸거나 (의료 영상 데이터) 아주 적은 양만 label이 있는 데이터에도 학습을 잘 할 수 있겠죠. 즉 이 논문의 주안점은 semi-supervised learning을 좀 더 잘 해보자는 것입니다. 

Overview


자 그러면 이렇게 추상적인 개념인 "연관성"을 어떻게 기계학습에 적용을 했는지 살펴볼까요? 먼저 이 논문에서 설정한 가정 하나를 보겠습니다.

"네트워크가 embedding feature vector를 제대로 만들어낸다면, 동일 클래스의 경우 feature space에서 vector간 서로 비슷하게 생겼을 것이다."

좀 너무 당연한가요? 그래도 이 당연한 가정에서 문제를 푸는 전략이 나옵니다. Loss 함수를 잘 디자인해서 labeled와 unlabeled data가 embedding feature space에서 서로 비슷한 녀석들끼리 가깝고 다른 녀석들끼리는 거리가 멀도록 하는 네트워크를 만드는 것이 목표입니다.

그러면 자연스럽게 embedding space에서 각각의 data 사이에 가깝다 멀다를 어떻게 정의해야할지에 대한 질문이 생깁니다. 이 논문에서는 이를 다음과 같이 정의했는데요.

각각, $A$: Labeled data, $B$: Unlabeled data의 feature space vector를 나타낼 때, $A_i$와 $B_j$ 사이의 유사도 $M){ij}$는 다음과 같이 내적 (inner product)로 정의합니다:
$$M_{ij}=A_i\cdot B_j$$ 여기까지만 보면 기존의 semi-supervised learning과 전혀 다를 것이 없지만, 앞으로 소개할 "연관성"이란 개념을 넣어 학습시키는 부분을 보시면 차이를 아실 수 있을 것 같습니다.

여기서부터가 중요합니다. 데이터가 들어왔을 때 embedding을 해주는 네트워크(초록색)가 있으면, 이로부터 feature space가 만들어지죠. 이런 embedding space에서 labeled data와 unlabeled 사이에 "연관성"을 정량화하기 위해 이 논문에서는 "walking"이라는 방법을 사용합니다.

위에 그림에서 보실 수 있듯이 labeled data의 feature vector에서 unlabeled data의 feature vector로 갔다(visit)가 다시 labeled data의 feature vector로 돌아왔을 때(walk) labeled data의 class가 바뀌지 않도록 제약을 주는 방식으로 loss 함수를 디자인합니다. 이 얘기들을 좀 더 수식화하여 나타내면 다음과 같습니다:

Transition Probability

$$\begin{align*}P^{ab}_{ij} = P(B_j|A_i) &:=(softmax_{cols}(M))_{ij} \\ &= exp(M_{ij})/\sum_{j'}exp(M_{ij'})\end{align*}$$

Round Trop Probability

$$\begin{align*}P^{aba}_{ij} &:= (P^{ab}P^{ba})_{ij} \\ &= \sum_{k}P^{ab}_{ik}P^{ba} _{kj}\end{align*}$$

Probability for correct walks

$$P(correct~walk) = \frac{1}{|A|}\sum_{i\sim j}P^{aba}_{ij}$$ where $i\sim j\Longleftrightarrow class(A_i) = class(A_j)$.

만약 우리의 가정대로 네트워크가 잘 학습을 하고 있다면 embedding feature space를 만들 때, walking하고 돌아왔을 때 여전히 class가 유지되야한다는 제약 때문에 자연스럽게 목표를 달성할 수 있겠다는 것이 key idea입니다.

이렇게 기본적으로 loss를 잘 디자인해서 연관성이라는 추상적인 개념을 실제 계산이 가능한 형태로 잘 녹여낸 것이 이 논문의 신선한 점이라 할 수 있습니다. 
$$\cal{L}_{total} =  \cal{L}_{walker}+\cal{L}_{visit}+\cal{L}_{classification}$$
총 세 부분으로 loss 함수가 나누어져있는 것을 알 수 있는데요 사실 이를 $\cal{L}_{walker}+\cal{L}_{visit}$와 $\cal{L}_{classification}$ 이렇게 두 부분으로 묶어 보시면 좀 더 이해가 편합니다. 

앞의 $\cal{L}_{walker}+\cal{L}_{visit}$ 부분이 오늘 소개할 association을 표현하는 loss 함수에 해당하구요 뒤의 $\cal{L}_{classification}$ 부분이 일반적으로 지도학습에서 사용하는 classfication loss가 되겠습니다. Label이 있는 녀석들에 대해서는 이런 loss 함수가 적용이 됩니다. (제가 예전에 정리해둔 DANN을 보신 분들이라면 좀 더 이해가 쉬울 수 있습니다.)

그래서 (아직까지는 어떻게 만들었는지는 모르지만) loss 함수를 잘 minimize하면 마법처럼 unlabeled data도 labeled data와 함께 잘 분류가 되게 하자는 것입니다. 그럼 하나하나 loss 함수를 이해해보겠습니다.

Walker loss


먼저 $\cal{L}_{walker}$입니다. 여기서 walk라는 이름은 제가 짐작하기로는 graph theory 쪽 용어를 가져온 것 같습니다. Graph theory 쪽 공부를 하다보면 data 하나를 점으로 보았을 때, 한 점에서 다른 점으로 가는 것을 "walk"라는 용어로 표현합니다.



위 그림이 walker loss에서 하고자 하는 일을 잘 설명해주고 있습니다. Labeled data의 class인 "개"에서 unlabeled loss를 방문한 후 다시 Labeled data로 돌아갔을 때 class가 여전히 "개"로 유지되길 바라는 것입니다. 여기서 주의하실 점은 돌아온 labeled data가 꼭 시작점의 labeled data와 정확히 일치할 필요는 없지만, class는 유지되기를 바라는 것입니다.

그래서 $\cal{L}_{walker}$에서는 만약 갔다가 돌아온 class가 달라지면 penalty를 주게 디자인 되어있습니다:$$\cal{L}_{walker}:=H(T,P^{aba})$$
여기서 $H$는 cross entropy이고 $T_{ij}$는 $class(A_i)=class(A_j)$일 때, $1/\#class(A_i)$이고 아닐 때는 0인 uniform distribution입니다. $P^{aba}$가 닮기를 바라는 $T$가 uniform distribution인 이유는 동일한 class로만 돌아오면 언제나 값이 같도록 하고 싶기 때문이죠. (동일한 class의 다른 이미지로 돌아왔다고 penalty를 주고 싶지 않을 것입니다.)

Visit loss


이제 visit loss입니다.
$$\cal{L}_{visit}:=H(V,P^{visit})$$
where, $ P_j^{visit} :=<P_{ij}^{ab}>_i, V:=1/|B|$

이 녀석이 하고자 하는 것도 그리 어렵지 않습니다 최대한 많은 sample을 다 골고루 봤으면 좋겠다는 것이죠. 대부분의 semi-supervised 방식에서는 자기가 고른 labeled data를 기준으로 가까이에 있는 unlabeled data만 보는데 그러지 말고 모두 다 보자는 것입니다. (그래서 $V$가 uniform distriction이지요. 그림을 통해보면 다음과 같습니다:



즉, 그림의 중간 동그라미 안에 들어가 있는 녀석들처럼 애매한 부분도 효과적으로 활용하고 싶다는 것이죠. 단, 여기서 unlabeled data가 너무 다른 경우 이 visit loss가 악영향을 끼치기 때문에 적절히 weight를 주는 것이 필요하다고 하네요. 

실험 결과 


MNIST


실험 결과도 상당히 놀랍습니다. 먼저 MNIST 결과에 대해서 학습 전후로 association이 어떻게 바뀌어 가는지 시각화해서 보여주면 다음과 같습니다:

Evolution of Associations

Top 부분이 학습을 시작해서 아주 약간 iteration을 돌렸을 때고 Bottom이 네트워크가 수렴한 후를 의미합니다.



이 실험 이후 MNIST에서 분류가 얼마나 잘 되었는지 확인을 해보면, test error가 0.96%로 매우 낮게 나왔는데요. 심지어 이렇게 틀린 경우도 설명이 가능하다고 얘기합니다. 우측에서 보이는 것이 confusion matrix인데요 여기서 틀린 부분을 좌하단에서 가져와 보여주면 사람이 봐도 왜 틀렸는지 이해가 갈만한 비슷비슷한 숫자들을 헷갈린 것이라고 애기하고 있습니다.

STL-10


저는 개인적으로 이 실험 결과가 매우 흥미롭습니다. STL-10은 RGB 이미지로 10개의 class가 있는 데이터셋인데요 약 5000개의 labeled 학습 이미지와 10만개의 unlabeled 학습 이미지가 있습니다. 재미있는 점은 이 unlabeled 이미지에 labeled 학습 이미지에 존재하지 않는 class의 이미지들도 있다는 것이죠.


그래서 결과를 보시면 매우 신기합니다. 위 두 줄이 labeled 이미지에 class가 있는 녀석의 nearest neighbor를 5개 뽑아본 것으로 상당히 잘 되는 것을 보실 수 있죠. 아래 두 줄이 제가 흥미롭게 생각한 부분입니다. Labeled 이미지 데이터셋에 존재하지 않는 class인 돌고래와 미어켓을 보여주니 네트워크가 내놓은 5개의 nearest neighbor인데요. 나름 비슷합니다. 돌고래의 삼각 지느러미 부분이 돛이나 비행기의 날개와 비슷하다 생각했는지 그런 이미지들이 같이 있고 미어켓은 신기하게도 동물들을 뽑아 온 것을 보실 수 있죠.

SVHN


이어서 SVHN에 대해 적용한 결과도 보여줍니다.

이 테이블이 보여주는 점은 자신들의 method가 unlabeled 데이터셋이 점점 많이 주어질 수록 에러율이 매우 내려간다는 것입니다. 즉, unlabeled 데이터로부터 실제로 연관 정보를 잘 뽑아내고 있다는 것이죠.

더욱 놀라운 점은 SVHN 데이터로 MNIST 데이터에 대한 Domain Adaptation 효과를 보여준 것입니다. 아래 테이블을 보시면 각각 SVHN에서만 학습시켰을 때, SVHN에서 학습시킨 후 MNIST로 Domain Adaptation 시켜줬을 때, MNIST에서만 학습시켰을 때, 세 가지 경우에서 MNIST 데이터셋에 대한 classification error를 알 수 있습니다. 

이를 자신들의 method와 Domain Adaptation에서 최근 유명했던 DANN, Domain Separation Nets 두 가지와 성능을 비교했는데 상당히 잘 되는 것을 볼 수 있습니다. 

Summary


지금까지 쭉 빠르게 논문을 살펴보았는데요 이 논문의 contibution을 정리해보자면 다음과 같습니다:
  1. 단순하지만 신선한 semi-supervised training method를 제안하였다.
  2. Tensorflow implentation이 있고 arbitray network architecture에 add-on처럼 범용적으로 붙여 사용할 수 있다. 
  3. SOTA methods에 비해 최대 64% 가까이의 성능 향상을 보였다. 
  4. Label이 매우 적을 때, SOTA methods를 매우 큰 차이로 이기는 결과를 확인하였다 (MNIST, SVHN)

게다가 심지어 ResNet 같은 복잡한 구조를 사용한 것도 아닌데 이런 결과가 나왔다는 것을 보면, 아직 성능이 더 개선이 될 여지가 충분하다는 것도 짚고 넘어가야할 것 같네요

이 논문을 읽고 아이디어가 새록새록 생기는데...일단 이 아이디어들은 나중에 졸업부터 하고 해야겠죠...지금 제 코가 석자라 ㅎㅎ 그래도 정말 재미있게 읽은 논문이었습니다.

다음 읽을거리



참고문헌:



댓글 5개:

  1. 항상 잘 읽고 갑니다. 감사합니다

    답글삭제
  2. 라벨링 비용이 큰 데이터를 다루고 있는데, 많은 도움이 되었습니다,

    답글삭제
  3. 제가 찾던 내용입니다. 정말 감사합니다.

    답글삭제