2017년 1월 13일 금요일

Backpropagation 설명 예제와 함께 완전히 이해하기

기계학습에 대해 관심을 갖고 신경망(Neural Networks, NN)을 처음 접하면서 가장 신비롭게 여겨졌던 부분이 바로 역전파(Backpropagation)를 통한 네트워크 학습(learning or training)이었습니다.
"신경망은 구체적으로 어떻게 학습이 되는거지?"
뭔가 직관적인 설명으로는 신경망이 계산한 답과 정답 사이의 오차를 계산하고 그 오차를 줄이도록 피드백을 주어 학습을 시킨다고 하는데....
1. 구체적으로 어떻게 그 변화 값들을 구하고 전파하는거지?
2. 그리고 그 방법이 "왜" 되는 것이지? 
그리고 역전파에 대해 어느 정도 머리로 이해하고 나서는...

생각보다 단순하고 매우 쉬운 수학만을 사용하네..
어짜피 Backpropagation은 이미 여러 라이브러리에서 잘 구현이 되어있어 사실상 내가 전혀 신경 쓸 일도 없을텐데 머리로 이해했으면 되었지 굳이 직접 손으로 풀어보고 코드를 짜보는 등 시간을 더 투자해야 할까? 등의 생각이 들었습니다.

먼저 마지막 의문에 대해서는 cs231n 수업 명강사로 유명한 Andrej Karpathy가 최근 쓴 글이 아주 잘 답변을 해주고 있죠; Yes you should understand backprop.
네 잘 알아야하고 코딩도 스스로 해봐야한다네요ㅎㅎ
저 글을 한 문장으로 요약하자면 다음과 같습니다:
"The problem with Backpropagation is that it is a leaky abstraction."
즉, 자칫하면 뉴럴 네트워크의 학습 과정을 추상적으로만 생각하고 지나가버리는 함정에 빠질 수 있다는 것!
실제로 구현해보지 않고 머리로만 이해한 사람은 신경망을 설계할 때, 단순히 layer를 쌓아주기만 하면 역전파(Backprop)가 "magically make them work on your data"라고 믿게 될 수 있으나...

실제로는 절대로 생각처럼 이상적으로 돌아가지만은 않고 여러가지 문제점들이 생길 수 있으며, 대표적으로 sigmoid나 tanh 함수가 saturate 되어 생기는 vanishing gradient 문제, dying ReLUs 문제 등을 자세히 소개해주고 있습니다. 관심있는 분들은 원글을 한 번 읽어보시길.

그럼 이 부분은 Andrej Karpathy에게 넘기고, 여기서는 첫 번째 의문에 대해 나름대로 정리해보겠습니다. 

물론 이미 같은 내용에 대한 답을 하기 위해 많은 자료들이 인터넷에 존재하고 여러 사람들이 다양한 관점으로 설명을 잘 해두었지만, 역시 무엇이든 제대로 이해하기 위해서는 스스로 시간을 투자해서 공부한 것을 나만의 언어와 논리로 정리하고 소화할 필요가 있는 것 같습니다.
이해를 위해 하는 것인데 너무 복잡한 문제로 시작하면 문제 자체보다 푸는 기술에만 집중할 수 있으니 가장 단순한 binary classification 문제에서 시작해보겠습니다. 

Classification problems with two classes (binary classification)

Binary (0 or 1) classification 문제를 풀 때, 기존의 신경망 구조는 마지막 output layer의 unit이 하나만 있어서 하나의 scalar 값을 계산하여 결과로 내보냅니다. 이 때, logistic function을 cross entropy loss function과 결합하여 사용하면 scalar 값이 0과 1 사이로 한정되면서 sum-of-squared loss와는 달리 두 가지 클래스 중 하나의 클래스에 대한 확률 값으로 자연스러운 해석이 가능해지죠.

Multiple, independent, binary classification 문제로의 확장은 아주 쉬운데, 각각의 target이 independent하므로 loglikelihood들을 더해주면 됩니다. 즉, 단일 unit이었던 때(single target)와 달리 마지막 layer에서 여러 unit을 갖습니다; 그림1
예시를 들자면 아마 이진 값으로 채워져있는 흑백 이미지 벡터를 생각하면 되지 않을까 싶네요. 그러면 각각의 픽셀에 대해 값이 있는지 없는지를 확률적으로 계산하는 문제가 되겠습니다. 

여기서 multi-class 문제로 좀 더 확장하는 것은 생각보다 간단하고 심지어 수식을 유도해보면 binary 문제와 거의 다를 것이 없습니다. 하지만 정작 이 글의 주제인 역전파에서 벗어날 수 있기 때문에 이 얘기는 다음에 따로 정리해보겠습니다. 

Backpropagation in 2-layer neural network

이제 이 문제를 가장 간단한 구조의 2-layer neural network로 풀기 위해 backpropagation을 계산해보며 실제로 네트워크가 어떻게 "학습"하는 지 확인해보겠습니다.

그림 1. 2-layer neural network

* Layer의 개수가 3개인데 왜 2-layer NN이라 하냐면 보통 neural network의 input layer는 세지 않기 때문에 그렇습니다.
** fully connected이기 때문에 각 층 사이에는 weight 값을 갖는 연결이 모든 조합에 대해 존재하고 편의상 그림에서는 생략해두었습니다. 
*** Hidden layer와 output layer의 activation function은 sigmoid 함수로 통일합니다. 


앞서 문제를 설정한대로 우리의 error function, E는 다음과 같습니다. 
$$E = -\sum_{i=1}^{nout}(t_i log(x_i)+(1-t_i)log(1-x_i))$$
where, $x_i = \frac{1}{1+e^{-s_i}}$, $s_i = \sum_{j=1}x_jw_{ji}$.

수식이 무엇을 의미하는지 직관적으로 이해하려면 극단적인 예제를 넣어보는게 가장 빠르죠ㅎㅎ 실제 class가 1일 때($t_i=1$), 신경망이 계산한 답이 1일 확률이 0이라는($x_i = 0$) output 결과가 나왔다고 해보고 E 값을 계산해보면 쉽게 무한대가 나오는 것을 알 수 있습니다.
반대로 제대로 확률 값이 1 즉, target $t_i$의 class가 1일 때, 실제로 1이라고 계산할 확률이 100%라고($x_i=1$) 신경망이 결과를 내었다면, E 값은 0으로 최소값을 갖습니다. 
즉, 우리의 목표는 E 값을 줄이는 것.
그럼 이제 이 오차 값을 줄이는 방향으로 네트워크를 업데이트 하면 되는데...
여기서 gradient descent라는 개념이 나옵니다. 
어떤 함수 $f$의 임의의 점 $x$에서의 구배(gradient)를 계산하면 $x$로부터 단위($\Delta x$)만큼 움직였을 때, $f$ 값이 변화하는 양을 표현하는 벡터가 나오고, 이 벡터는 항상 가장 가파르게 $f$ 값이 증가하는 방향을 가르키기 때문에 함수 $f$를 E로 두고 구배가 가르키는 반대방향(descent)으로 네트워크를 업데이트 해주면 됩니다. 

Data인 input을 바꾸는 것이 목적이 아니기 때문에 네트워크의 output 값에 영향을 주는 변수 중 우리가 수정할 것은 layer 간의 weight, $w$들입니다. 따라서 우리가 해야하는 것은 각 layer의 weight에 대한 E 함수의 gradient를 계산하고 이 변화량을 원래 w에 반대 방향 (negartive, -1)으로 더해주는 것이죠. 


신경망이 학습하는 것을 순서대로 나열해보면, Input layer로부터 각 layer를 지나치며 weight들을 곱하고, activation function을 지난 다음, 마지막 output layer에서 오차 값을 계산한 이후(feed forward), 업데이트를 위해 순차적으로 거꾸로(back) 돌아가야(propagation)합니다.
따라서 돌아갈 때 순서대로 가장 바깥 쪽 output과 hidden layer 사이의 weight부터 바꿔보겠습니다. $i$번째 output인 $x_i$에 대한 오차값에 대해 $w_{ji}$가 바뀌어야하는 정도를 계산해려면 다음과 같은 수식이 필요합니다:
$$\frac{\partial E}{\partial w_{ji}} = \frac{\partial E}{\partial x_i} \frac{\partial x_i}{\partial s_i} \frac{\partial s_i}{\partial w_{ji}}$$
  1. $\frac{\partial E}{\partial x_i} = \frac{-t_i}{x_i}+\frac{1-t_i}{1-x_i} = \frac{x_i-t_i}{x_i(1-x_i)}$
  2. $\frac{\partial x_i}{\partial s_i} = x_i(1-x_i)$
  3. $\frac{\partial s_i}{\partial w_{ji}} = x_j$
$$\therefore \frac{\partial E}{\partial w_{ji}} = (x_i-t_i)x_j$$
(여기서 순서를 잠시 짚고 넘어갈 필요가 있습니다. 1번과 2번을 곱하면 $\frac{\partial E}{\partial s_i}$가 나옵니다. 그 후 3번을 차례로 계산하여 곱하면 결과적으로 $\frac{\partial E}{\partial w_{ji}} = (x_i-t_i)x_j$가 됩니다. 먼저 $\frac{\partial E}{\partial s_i}$를 계산했다는 것을 기억해주시면 됩니다.)


같은 논리로 그 다음 hidden과 input layer 간의 weight, $w_{kj}$에 대해 변화량을 계산하면,
$$\frac{\partial E}{\partial w_{kj}} = \frac{\partial E}{\partial s_j} \frac{\partial s_j}{\partial w_{kj}}$$
  1. $\frac{\partial E}{\partial s_j} = \sum_{i=1}^{nout} \frac{\partial E}{\partial s_i}\frac{\partial s_i}{\partial x_j}\frac{\partial x_j}{\partial s_j} = \sum_{i=1}^{nout}(x_i-t_i)(w_{ji})(x_j(1-x_j))$ 
  2. $\frac{\partial s_j}{\partial w_{kj}} = x_k ~(\because s_j=\sum_{k=1}x_kw_{kj})$
$$\therefore \frac{\partial E}{\partial w_{kj}} = \sum_{i=1}^{nout}(x_i-t_i)(w_{ji})(x_j(1-x_j))(x_k)$$
여기서 $x_k$는 feed forward할 때도 $i=\{1,\cdots, nout\}$까지의 모든 값에 영향을 주기 때문에 backprop 때도 $\sum_{i=1}^{nout}$을 고려하여 계산해야합니다. 위 그림에서 빨간 화살표들을 보면 되겠습니다.

여기서도 $\frac{\partial E}{\partial s_j}$를 구한 것을 보실 수 있습니다. 앞에서도 강조했었는데 이걸 기억하면 위에서 유도한 식들을 general multiplayer network로 확장하는 것도 매우 쉽습니다. 각 layer의 $\frac{\partial E}{\partial w_{ij}}$를 구하고 싶다면, 항상 $\frac{\partial E}{\partial s_j}$를 계산하고, $\frac{\partial s_j}{\partial w_{kj}} = x_k$를 곱해주는 식이 되면 됩니다. 

참고문헌: https://www.ics.uci.edu/~pjsadows/notes.pdf
(이 문헌을 참고해서 공부하긴 했지만 보시다보면 수식 전개에서 약간의 오타가 있으니 감안하고 보셔야 합니다.)



댓글 28개:

  1. 와.. 공부하려고 여러 블로그 돌아다녀봤는데 여기가 가장 이해가 쉬웠어요

    답글삭제
    답글
    1. 뚜비님 댓글 감사합니다. 이해가 잘 되셨다니 다행입니다! 이해하기 쉬웠다고 하시니 열심히 풀어 쓴 보람이 있네요 ㅎㅎ

      삭제
    2. 인정합니다.. 정말 감사하네요 ~ 수학을 잘 못하는 공대생인지라 계산을 오래 하였으나 이해는 확실히 가네요!

      삭제
    3. 만약 L 까지 있다면, 파셜E/파셜Wlk 는 어떻게 되는건지 알 수있나요? ㅜㅜ

      삭제
  2. 정말 도움이 많이 되었습니다 . 궁금한 것이 있습니다. error function에서 "전체 training set 갯수" 에 대해서는 sigma가 없는것 같은데, 여기서는 training set 1개에 대해서 error function을 구하신건가요?

    답글삭제
    답글
    1. 안녕하세요. 댓글 감사합니다. sigma라 하심은 sum을 말씀하신건가요? 그런 의미라면 말씀하신 부분이 맞습니다.

      삭제
    2. 작성자가 댓글을 삭제했습니다.

      삭제
    3. 아 네 sum 을 얘기하는거였습니다. 그렇군요. ㅎㅎ 감사합니다

      삭제
  3. 잘 읽었습니다^^,
    편미분을 통해 특정 weight가 Cost (E함수의 아웃풋)에 얼마나 영향을 미치는지를 정량적으로 나타내는것이 Gradient 라고 생각하는데, Gradient의 반대되는 방향으로 움직이는 양은 learning rate라고 부르는 alpha 값으로 정해지는 건가요?

    답글삭제
    답글
    1. 안녕하세요 Unknown님. 정확히 이해하고 계십니다. Gradient는 Cost 함수에 가장 큰 영향을 줄 수 있는 방향이고 이를 alpha 값으로 조정합니다. 현재 최선의 방향이 꼭 이후에도 최선일 이유가 없으니까요.

      삭제
  4. 좋은 설명 감사드립니다:) 다름이 아니라 위의 글의 요지는 결국 Back Propagation이 Weight를 업데이트 한다는 것이잖아요.

    이런 업데이트된 Weight 값이 결국 Cost Funtion을 최소화 하는데 필요한 값들 이라고 생각합니다.

    다음 단계로 (Andrew 선생님은 강의에서) 역전파 이후에 SGD(혹은 Matlab 함수인 fminunc)를 써서 Cost Function을 최소화 하는 과정을 거치시던 데요.


    역전파가 Weight 업데이트 해주는 방법인데 이후 SGD를 또 해주는 이유는 무엇일까요? ㅜㅜㅜ (각각의 의미는 알겠는데 순서가 이해가 되지 않아서요.)


    너무 기본적인 궁금증 일지라도 시간을 내주시면 감사하겠습니다.

    답글삭제
    답글
    1. 안녕하세요 Unknown님. 이해를 잘 하고 계시는데 마지막에 헷갈리신 모양입니다. Backpropagation이라는 과정을 사용하는 것이 SGD입니다. SGD가 Stochastic Gradient Descent의 약자이므로 결국 Gradient를 계산해서 Cost가 증가하는 반대 (Descent) 방향으로 업데이트를 해주는데 모든 data에 대해 계산을 해주는 것은 현실적으로 불가능하니까 (계산량, 시간 등의 한계) 일부만 randomly sample 하여 (stochastic) 계산을 하는 것입니다.

      삭제
  5. 많은 도움이 되었습니다! 정말 감사합니다!

    답글삭제
    답글
    1. 덕윤님 안녕하세요 ㅎ 도움이 되셨다니 다행입니다 :)

      삭제
  6. 감사합니다. 막힌 부분을 콕 집어서 설명해 주셨네요.
    지금까지 본 설명 중에 제일 명쾌하고 쉬웠습니다.

    답글삭제
    답글
    1. 강보님 안녕하세요! 도움이 되셨다니 다행입니다.

      삭제
  7. backprop 이 궁금해 검색하다가 보게 되었습니다. 좋은 글 감사합니다!
    그런데 연쇄법칙을 적용할 때 각 층의 변수를 모두 x_j 로 번역하셔서 헷갈렸습니다 ㅎㅎ 원문과 같이 y, s, x 로 달리 두어야 할 듯 합니다.

    답글삭제
  8. 좋은 글 잘 읽었습니다! 저도 MistyMochi님의 의견에 동의합니다 :)

    답글삭제
  9. 지리노! 잘 봤습니다 ㅎ

    답글삭제
  10. 감사합니다! 좋은 글 여러번 읽어야 하겠네요😄

    답글삭제
  11. 이해가 가긴 했는데 예제를 통해 직접 하려고 하다보니 너무 어려운거같아서 질문드리고싶은데 혹시 메일 주소 가능하실까요?
    직접 예제를 구현하고자 해보는데 어려움을 겪고있습니다.

    답글삭제
    답글
    1. aaksj0605@gmail.com 으로 질문드리고싶습니다.

      삭제
    2. 가나다다 님 안녕하세요. https://cs231n.github.io/optimization-2/ 유명한 cs231n의 예제를 한 번 읽어보시는 것은 어떨까요? computational graph와 코드가 같이 있으니 도움이 되실 것 같습니다.

      삭제
  12. 작성자가 댓글을 삭제했습니다.

    답글삭제
  13. 작성자가 댓글을 삭제했습니다.

    답글삭제