Friday, June 9, 2017

RNN implementation using only Numpy (.ipynb with ENG, KOR explanations / Binary operation: addition)


Anyone Can Learn To Code an LSTM-RNN in Python

[Reference]
If you want more detailed background and explanations about RNNs:
Original article (Eng): https://iamtrask.github.io/2015/11/15/anyone-can-code-lstm/
Translation (Kor): https://jaejunyoo.blogspot.com/2017/06/anyone-can-learn-to-code-LSTM-RNN-Python.html
GitHub repo (Kor, Eng): https://github.com/jaejun-yoo/RNN-implementation-using-Numpy-binary-digit-addition

[Objective]

  • Understand RNNs with a simple NumPy toy implementation.
  • Train the RNN on a binary operation, e.g., addition (a worked example follows below).
  • Check whether the trained RNN extends to unseen data with longer digits (e.g., train on 8-bit numbers, test on 10-bit numbers).
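
For a concrete instance of the task, adding 2 and 3 in 8-bit binary looks like this:

    0 0 0 0 0 0 1 0   (2)
  + 0 0 0 0 0 0 1 1   (3)
  -----------------
    0 0 0 0 0 1 0 1   (5)

Adding the two 1s in the second-lowest position yields a 0 plus a carry into the next position; propagating that carry through its hidden state is exactly what the RNN has to learn.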


In [1]:
# Import libraries
import copy, numpy as np
import matplotlib.pyplot as plt
np.random.seed(0)
%matplotlib inline
In [2]:
# Utility functions
def sigmoid(x):
    output = 1/(1+np.exp(-x))
    return output

def sigmoid_output_to_derivative(output):    
    return output*(1-output)
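
The derivative helper relies on the identity sigmoid'(x) = sigmoid(x)*(1 - sigmoid(x)), so the derivative can be computed from the sigmoid output alone. As a quick sanity check (my addition, not part of the original notebook), the analytic value can be compared against a finite difference:

# Finite-difference check of sigmoid_output_to_derivative (sketch)
x = 0.5
eps = 1e-6
numeric = (sigmoid(x + eps) - sigmoid(x - eps)) / (2 * eps)
analytic = sigmoid_output_to_derivative(sigmoid(x))
print(numeric, analytic)  # both approximately 0.235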
In [3]:
# Decide the maximum binary dimension (limit the binary numbers to 8 bits)
max_binary_dim = 8
largest_number = pow(2,max_binary_dim)
#print(2**8)
#print(pow(2,max_binary_dim))
#print(range(2**3))

Create binary lookup table

In [4]:
# Create a binary lookup table (just for convenience; it does not have to be done this way)
# np.unpackbits e.g.
print(np.unpackbits(np.array([8], dtype = np.uint8)))
print("====================")
# The integers 0..largest_number-1 to be converted to binary, as a list inside a list:
# e.g.
# binary_gonna_be = np.array([range(largest_number)], dtype=np.uint8).T
# print(binary_gonna_be)

# This completes the binary lookup table
binary = np.unpackbits(np.array([range(largest_number)], dtype=np.uint8).T, axis = 1)
print(binary.shape, binary)
print("====================")
int2binary = {}
for i in range(largest_number):
    int2binary[i] = binary[i]
print("lookup table test")
print(binary[3], int2binary[3])
#print(int2binary)
[0 0 0 0 1 0 0 0]
====================
(256, 8) [[0 0 0 ..., 0 0 0]
 [0 0 0 ..., 0 0 1]
 [0 0 0 ..., 0 1 0]
 ..., 
 [1 1 1 ..., 1 0 1]
 [1 1 1 ..., 1 1 0]
 [1 1 1 ..., 1 1 1]]
====================
lookup table test
[0 0 0 0 0 0 1 1] [0 0 0 0 0 0 1 1]
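
For reference, the same lookup table can also be built with Python's string formatting, which is the trick the notebook uses later for the 10-bit test. A minimal sketch (int2binary_alt is a hypothetical name, not used elsewhere):

# Equivalent lookup table via zero-padded binary strings (sketch)
int2binary_alt = {i: np.array(list(map(int, format(i, "08b")))) for i in range(largest_number)}
assert (int2binary_alt[3] == int2binary[3]).all()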

Initial parameter setting

In [5]:
alpha = 0.1 # learning rate
input_dim = 2 # we add digit by digit, so the input is the pair of digits at the n-th position of the two binary numbers
hidden_dim = 16 # you can vary this and see what happens
output_dim = 1 # the output is the single digit produced by adding the two digits at the n-th position, so it is one-dimensional, e.g. 1(2) + 1(2) = 0(2) with carry 1

# weight initialization
synapse_0 = 2*np.random.random((input_dim,hidden_dim))-1
synapse_1 = 2*np.random.random((hidden_dim,output_dim))-1
synapse_h = 2*np.random.random((hidden_dim,hidden_dim))-1

print(synapse_0.shape, synapse_1.shape, synapse_h.shape)
(2, 16) (16, 1) (16, 16)
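
Before the training loop, it may help to trace the shapes through a single forward step, h_t = sigmoid(X*synapse_0 + h_{t-1}*synapse_h) and o_t = sigmoid(h_t*synapse_1). A throwaway check (my addition, not in the original notebook; the _demo names are hypothetical):

# One forward step, just to confirm the shapes (sketch)
X_demo = np.array([[1, 0]])        # the two digits at one position, shape (1, 2)
h_prev = np.zeros(hidden_dim)      # initial hidden state, shape (16,)
h_demo = sigmoid(np.dot(X_demo, synapse_0) + np.dot(h_prev, synapse_h))
o_demo = sigmoid(np.dot(h_demo, synapse_1))
print(h_demo.shape, o_demo.shape)  # (1, 16) (1, 1)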
In [6]:
# buffers for the weight updates, plus history for visualization
s0_update = np.zeros(synapse_0.shape) # s0_update = np.zeros_like(synapse_0)
s1_update = np.zeros(synapse_1.shape) 
sh_update = np.zeros(synapse_h.shape) 

overallError_history = list()
accuracy = list()
accuracy_history = list()
accuracy_count = 0

Training!!

In [7]:
max_iter = 20000
for j in range(max_iter):
    # Randomly pick two integers and fetch their binary representations from the lookup table.
    a_int = np.random.randint(1,largest_number//2)
    a = int2binary[a_int]
    b_int = np.random.randint(1,largest_number//2)
    b = int2binary[b_int]
    # Compute the true answer and store it in binary form.
    c_int = a_int + b_int
    c = int2binary[c_int]
    
    # Buffer for the binary sum predicted by the RNN.
    pred = np.zeros_like(c)
    
    overallError = 0
    
    output_layer_deltas = list()
    hidden_layer_values = list()
    hidden_layer_values.append(np.zeros(hidden_dim)) # dim: (16,)

    # Feed forward!
    # Loop in reverse: the last index of the binary array is the least significant bit,
    # and addition has to start there.
    # e.g.
    # 10(2) + 11(2), for the first iteration: X = [[0,1]] y = [[1]]
    for position in reversed(range(max_binary_dim)):
        
        # Fetch the input digits and the output label for this position
        X = np.array([[a[position],b[position]]]) # dim: (1, 2), e.g. [[1,0]]
        y = np.array([[c[position]]]) # dim: (1, 1), e.g. [[1]]
        
        # Compute the hidden layer: h_t = sigmoid(X*W_{hx} + h_{t-1}*W_{hh})
        hidden_layer = sigmoid(np.dot(X,synapse_0) + np.dot(hidden_layer_values[-1],synapse_h)) # dim: (1, 16)
        
        # Compute the output layer
        output_layer = sigmoid(np.dot(hidden_layer,synapse_1)) # dim: (1, 1), e.g. [[0.47174173]]
        
        # Compute the error
        output_layer_error = y-output_layer # dim: (1, 1) 
        
        # Accumulate for the error-curve display
        overallError += np.abs(output_layer_error[0]) # dim: (1, )          
        
        # Pre-compute the delta for later use in the backpropagation step
        output_layer_deltas.append((output_layer_error) * sigmoid_output_to_derivative(output_layer))        
        
        # Store the model's prediction for the current position
        pred[position] = np.round(output_layer[0][0])
        
        # Store the hidden layer computed so far by appending it to the list
        hidden_layer_values.append(copy.deepcopy(hidden_layer)) 
    
    if (j%100 == 0):
        overallError_history.append(overallError[0])
    
    # Now, backpropagation!
        
    # Backpropagation starts at the last time step processed (the most significant bit),
    # where no "future" hidden layer exists yet, so its delta is initialized with zeros.
    future_hidden_layer_delta = np.zeros(hidden_dim)
    
    # Backprop must walk back through time; since the forward pass ran over positions in reverse,
    # an ordinary forward loop over positions does exactly that.
    for position in range(max_binary_dim):
        
        # Fetch the values needed for this step
        X = np.array([[a[position],b[position]]])
        hidden_layer = hidden_layer_values[-position-1]
        prev_hidden_layer = hidden_layer_values[-position-2]
        
        # Gradient flowing back from the output-layer error at this position (time step)
        output_layer_delta = output_layer_deltas[-position-1]
        
        # Important part! (Backpropagation)
        # The gradient arriving at the current hidden layer is the sum of
        # the gradient coming back from the future hidden layer and
        # the gradient coming back from the current output layer,
        # multiplied by the sigmoid derivative.
        # Reason: think of backpropagating through the forward step h_t = sigmoid(X*W_{hx} + h_{t-1}*W_{hh}).
        hidden_layer_delta = (np.dot(future_hidden_layer_delta,synapse_h.T) + np.dot(output_layer_delta,synapse_1.T)) \
                            * sigmoid_output_to_derivative(hidden_layer)
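        # Added note: with z_t = X*W_{hx} + h_{t-1}*W_{hh} and h_t = sigmoid(z_t),
        # the loss gradient reaches h_t along two paths:
        #   (1) through the output at time t:   output_layer_delta dot synapse_1^T
        #   (2) through the next hidden state:  future_hidden_layer_delta dot synapse_h^T
        # Multiplying their sum by sigmoid'(z_t) = h_t*(1 - h_t) gives the delta w.r.t. z_t.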
        
        # Accumulate the gradients over all 8 positions and apply them in one update afterwards.
        # Reason: the hidden layer values must not change while backpropagation is still in progress,
        # so save the updates until the loop has finished every position.
        s1_update += np.atleast_2d(hidden_layer).T.dot(output_layer_delta)
        sh_update += np.atleast_2d(prev_hidden_layer).T.dot(hidden_layer_delta)
        s0_update += X.T.dot(hidden_layer_delta)
        
        # Prepare for the next step: the current hidden_layer_delta becomes the future hidden_layer_delta.
        future_hidden_layer_delta = hidden_layer_delta

    # Update the weights (scaled by the learning rate)
    synapse_1 += s1_update*alpha
    synapse_0 += s0_update*alpha
    synapse_h += sh_update*alpha
    
    # Reset the accumulated updates for the next training pair (a, b)
    s1_update *= 0
    s0_update *= 0    
    sh_update *= 0
    
    # Compute accuracy
    check = np.equal(pred,c)
    if np.sum(check) == max_binary_dim:
        accuracy_count += 1
    if (j%100 == 0):
        accuracy_history.append(accuracy_count)
        accuracy_count = 0
    
    
    if (j % 100 == 0):
        print ("Error:" + str(overallError))
        print ("Pred:" + str(pred))  # 예측값
        print ("True:" + str(c))  # 실제값

        final_check = np.equal(pred,c)
        print (np.sum(final_check) == max_binary_dim)

        out = 0

        for index, x in enumerate(reversed(pred)):
            out += x * pow(2, index)
        print (str(a_int) + " + " + str(b_int) + " = " + str(out))
        print ("------------")
Error:[ 3.97540736]
Pred:[0 0 0 0 0 0 0 0]
True:[0 1 0 0 0 1 1 1]
False
10 + 61 = 0
------------
Error:[ 4.0325467]
Pred:[0 1 1 0 1 1 0 1]
True:[1 0 0 1 0 1 0 1]
False
111 + 38 = 109
------------
Error:[ 3.79424599]
Pred:[0 0 0 0 0 0 0 0]
True:[0 0 1 1 1 0 0 0]
False
55 + 1 = 0
------------
Error:[ 3.87395152]
Pred:[0 0 0 0 0 0 0 0]
True:[0 0 0 1 1 0 0 1]
False
20 + 5 = 0
------------
Error:[ 4.01458433]
Pred:[1 1 1 1 1 1 1 1]
True:[0 1 0 1 0 1 1 0]
False
12 + 74 = 255
------------
Error:[ 3.99141219]
Pred:[0 0 0 0 0 0 0 0]
True:[0 1 1 0 0 0 1 0]
False
30 + 68 = 0
------------
Error:[ 3.97121353]
Pred:[0 0 0 0 1 0 0 0]
True:[0 0 1 0 0 0 0 1]
False
28 + 5 = 8
------------
Error:[ 4.02250446]
Pred:[1 1 1 1 1 1 1 0]
True:[0 1 0 1 1 0 1 1]
False
60 + 31 = 254
------------
Error:[ 4.04634296]
Pred:[1 1 0 1 0 1 0 1]
True:[0 1 1 1 1 1 1 1]
False
23 + 104 = 213
------------
Error:[ 3.87790083]
Pred:[1 0 1 1 1 1 1 0]
True:[1 1 0 1 0 1 1 0]
False
124 + 90 = 190
------------
Error:[ 3.8993887]
Pred:[0 0 0 0 0 0 0 0]
True:[0 1 0 1 0 1 0 1]
False
75 + 10 = 0
------------
Error:[ 4.00422894]
Pred:[0 0 0 0 0 0 1 0]
True:[0 0 0 0 0 1 0 1]
False
4 + 1 = 2
------------
Error:[ 3.97174173]
Pred:[0 1 1 0 0 1 1 0]
True:[1 0 0 0 0 0 1 0]
False
15 + 115 = 102
------------
Error:[ 4.08583973]
Pred:[0 0 0 0 0 0 0 0]
True:[0 1 1 0 1 1 0 0]
False
95 + 13 = 0
------------
Error:[ 4.20674656]
Pred:[0 0 0 0 0 0 0 0]
True:[0 1 1 1 0 1 1 1]
False
28 + 91 = 0
------------
Error:[ 4.25291382]
Pred:[0 0 0 0 0 0 0 0]
True:[0 1 1 1 1 1 1 0]
False
88 + 38 = 0
------------
Error:[ 3.53099877]
Pred:[0 0 0 0 0 0 0 0]
True:[0 1 0 0 1 0 0 0]
False
40 + 32 = 0
------------
Error:[ 3.89231997]
Pred:[1 1 1 1 1 1 1 1]
True:[1 0 1 1 0 1 0 1]
False
79 + 102 = 255
------------
Error:[ 4.02295019]
Pred:[0 1 0 0 1 1 1 1]
True:[0 1 1 1 0 0 1 0]
False
75 + 39 = 79
------------
Error:[ 3.880903]
Pred:[1 1 1 1 1 1 0 1]
True:[0 1 1 0 1 1 0 1]
False
105 + 4 = 253
------------
Error:[ 3.92571431]
Pred:[1 1 1 1 1 1 1 1]
True:[1 0 0 1 1 0 1 1]
False
32 + 123 = 255
------------
Error:[ 3.7589035]
Pred:[0 1 0 0 1 0 1 0]
True:[1 0 0 0 1 1 1 0]
False
39 + 103 = 74
------------
Error:[ 3.87579648]
Pred:[1 1 1 1 1 0 1 0]
True:[1 0 1 0 1 0 1 1]
False
58 + 113 = 250
------------
Error:[ 4.19142604]
Pred:[0 0 0 0 1 1 0 0]
True:[0 1 1 1 0 1 0 1]
False
31 + 86 = 12
------------
Error:[ 3.86476286]
Pred:[0 0 1 1 0 0 1 0]
True:[0 0 1 1 1 1 1 0]
False
5 + 57 = 50
------------
Error:[ 4.05545722]
Pred:[1 1 0 1 1 1 1 1]
True:[1 0 0 1 0 1 0 1]
False
38 + 111 = 223
------------
Error:[ 4.11675483]
Pred:[1 1 1 1 1 1 1 1]
True:[1 0 1 0 0 0 0 1]
False
95 + 66 = 255
------------
Error:[ 3.90005213]
Pred:[1 0 0 0 0 0 1 1]
True:[1 1 1 0 1 1 1 1]
False
120 + 119 = 131
------------
Error:[ 4.16466363]
Pred:[0 0 1 1 0 0 1 1]
True:[1 0 0 1 0 1 1 0]
False
121 + 29 = 51
------------
Error:[ 3.76799453]
Pred:[0 0 0 1 1 0 0 0]
True:[0 1 1 0 1 0 0 0]
False
12 + 92 = 24
------------
Error:[ 3.74916647]
Pred:[1 1 1 1 0 1 1 0]
True:[0 1 1 1 0 1 1 0]
False
10 + 108 = 246
------------
Error:[ 4.19330056]
Pred:[0 1 1 1 1 1 1 0]
True:[1 0 0 0 0 1 1 0]
False
72 + 62 = 126
------------
Error:[ 3.63455126]
Pred:[1 1 0 1 1 1 1 0]
True:[0 1 0 1 0 1 1 0]
False
72 + 14 = 222
------------
Error:[ 3.79994751]
Pred:[1 1 0 0 1 1 1 0]
True:[1 1 0 0 1 1 0 0]
False
97 + 107 = 206
------------
Error:[ 3.62919201]
Pred:[1 0 0 0 1 0 1 0]
True:[1 0 1 1 1 0 1 0]
False
112 + 74 = 138
------------
Error:[ 3.73175955]
Pred:[0 0 1 0 0 0 0 0]
True:[0 1 0 1 0 1 1 0]
False
28 + 58 = 32
------------
Error:[ 3.37270867]
Pred:[0 1 0 1 1 0 1 1]
True:[0 1 0 1 0 0 1 1]
False
12 + 71 = 91
------------
Error:[ 3.42578141]
Pred:[1 0 0 1 1 1 1 1]
True:[1 0 0 1 1 1 1 1]
True
93 + 66 = 159
------------
Error:[ 3.96962954]
Pred:[0 0 1 1 0 1 1 0]
True:[0 1 0 0 1 0 0 0]
False
63 + 9 = 54
------------
Error:[ 3.20748485]
Pred:[1 1 1 0 1 1 1 1]
True:[0 1 1 0 1 1 1 1]
False
15 + 96 = 239
------------
Error:[ 3.07873713]
Pred:[0 0 1 1 1 1 1 1]
True:[0 0 1 0 1 1 1 1]
False
33 + 14 = 63
------------
Error:[ 3.27911476]
Pred:[0 1 1 0 0 0 1 1]
True:[0 1 0 1 0 0 1 1]
False
24 + 59 = 99
------------
Error:[ 3.91718754]
Pred:[1 1 0 0 1 0 1 1]
True:[1 0 1 1 0 0 0 1]
False
52 + 125 = 203
------------
Error:[ 2.84676144]
Pred:[1 1 0 1 1 0 1 0]
True:[1 1 0 1 1 0 1 0]
True
98 + 120 = 218
------------
Error:[ 3.56648456]
Pred:[0 1 0 0 1 1 1 0]
True:[0 1 0 1 0 0 0 0]
False
37 + 43 = 78
------------
Error:[ 2.61810298]
Pred:[1 1 1 0 0 1 1 1]
True:[1 1 0 0 0 0 1 1]
False
97 + 98 = 231
------------
Error:[ 2.34280248]
Pred:[1 0 0 1 1 0 0 0]
True:[1 1 0 1 1 0 0 0]
False
112 + 104 = 152
------------
Error:[ 3.83818822]
Pred:[0 1 1 1 1 0 1 0]
True:[1 0 0 0 0 0 1 0]
False
120 + 10 = 122
------------
Error:[ 2.57289539]
Pred:[0 1 1 0 0 1 1 0]
True:[0 1 1 0 0 1 1 0]
True
16 + 86 = 102
------------
Error:[ 3.31327062]
Pred:[1 1 0 0 1 0 0 0]
True:[1 0 0 0 1 0 1 0]
False
19 + 119 = 200
------------
Error:[ 2.69438107]
Pred:[1 0 0 1 1 0 0 1]
True:[1 1 0 1 0 0 0 1]
False
100 + 109 = 153
------------
Error:[ 2.95509023]
Pred:[1 0 1 1 1 0 0 1]
True:[1 0 1 1 0 0 0 1]
False
111 + 66 = 185
------------
Error:[ 2.24368312]
Pred:[0 1 0 1 1 1 0 1]
True:[0 1 0 1 1 1 0 1]
True
24 + 69 = 93
------------
Error:[ 3.63229369]
Pred:[1 0 1 1 0 1 0 0]
True:[1 1 0 0 0 1 0 0]
False
78 + 118 = 180
------------
Error:[ 2.52654005]
Pred:[1 1 0 0 1 1 1 1]
True:[1 1 0 0 1 1 1 1]
True
110 + 97 = 207
------------
Error:[ 2.83996122]
Pred:[1 1 1 0 1 0 0 1]
True:[1 1 1 0 1 0 0 1]
True
111 + 122 = 233
------------
Error:[ 1.07711253]
Pred:[0 1 1 0 0 0 1 0]
True:[0 1 1 0 0 0 1 0]
True
97 + 1 = 98
------------
Error:[ 1.51067066]
Pred:[1 0 0 1 0 0 0 0]
True:[1 0 0 1 0 0 0 0]
True
72 + 72 = 144
------------
Error:[ 1.24360068]
Pred:[0 1 1 0 1 1 1 0]
True:[0 1 1 0 1 1 1 0]
True
85 + 25 = 110
------------
Error:[ 3.40019024]
Pred:[1 0 1 1 1 1 1 1]
True:[0 1 1 1 1 1 1 1]
False
66 + 61 = 191
------------
Error:[ 1.63590052]
Pred:[1 0 0 0 0 0 1 1]
True:[1 0 0 0 0 0 1 1]
True
98 + 33 = 131
------------
Error:[ 2.72216248]
Pred:[1 0 0 0 0 0 0 0]
True:[1 0 0 0 1 0 0 0]
False
107 + 29 = 128
------------
Error:[ 0.53601483]
Pred:[0 1 0 1 0 1 0 1]
True:[0 1 0 1 0 1 0 1]
True
16 + 69 = 85
------------
Error:[ 1.53384966]
Pred:[1 0 1 0 1 1 1 0]
True:[1 0 1 0 1 1 1 0]
True
89 + 85 = 174
------------
Error:[ 1.23589765]
Pred:[0 1 0 0 1 1 1 0]
True:[0 1 0 0 1 1 1 0]
True
32 + 46 = 78
------------
Error:[ 1.52746406]
Pred:[1 1 0 1 1 0 0 1]
True:[1 1 0 1 1 0 0 1]
True
122 + 95 = 217
------------
Error:[ 0.8104674]
Pred:[0 1 1 1 1 1 1 0]
True:[0 1 1 1 1 1 1 0]
True
62 + 64 = 126
------------
Error:[ 1.41268677]
Pred:[1 0 1 1 1 0 0 0]
True:[1 0 1 1 1 0 0 0]
True
106 + 78 = 184
------------
Error:[ 1.58945965]
Pred:[0 1 1 1 1 0 1 1]
True:[0 1 1 1 1 0 1 1]
True
95 + 28 = 123
------------
Error:[ 1.57149055]
Pred:[1 0 1 1 0 1 1 1]
True:[1 0 1 1 0 1 1 1]
True
109 + 74 = 183
------------
Error:[ 0.99503431]
Pred:[1 1 1 1 0 1 1 1]
True:[1 1 1 1 0 1 1 1]
True
122 + 125 = 247
------------
Error:[ 1.23742639]
Pred:[1 1 0 0 0 1 1 1]
True:[1 1 0 0 0 1 1 1]
True
79 + 120 = 199
------------
Error:[ 1.00171615]
Pred:[1 0 1 0 0 1 1 1]
True:[1 0 1 0 0 1 1 1]
True
42 + 125 = 167
------------
Error:[ 0.88834079]
Pred:[1 0 0 1 0 1 1 0]
True:[1 0 0 1 0 1 1 0]
True
126 + 24 = 150
------------
Error:[ 0.62501817]
Pred:[1 1 0 0 0 0 0 1]
True:[1 1 0 0 0 0 0 1]
True
112 + 81 = 193
------------
Error:[ 0.68401851]
Pred:[0 1 0 1 0 1 0 1]
True:[0 1 0 1 0 1 0 1]
True
71 + 14 = 85
------------
Error:[ 0.35307906]
Pred:[0 1 0 1 0 1 1 0]
True:[0 1 0 1 0 1 1 0]
True
82 + 4 = 86
------------
Error:[ 1.27283888]
Pred:[1 0 0 0 0 0 0 1]
True:[1 0 0 0 0 0 0 1]
True
30 + 99 = 129
------------
Error:[ 0.59094215]
Pred:[0 1 0 1 0 0 1 0]
True:[0 1 0 1 0 0 1 0]
True
78 + 4 = 82
------------
Error:[ 0.58736554]
Pred:[1 0 0 1 1 1 0 1]
True:[1 0 0 1 1 1 0 1]
True
82 + 75 = 157
------------
Error:[ 0.38656398]
Pred:[1 0 0 0 1 1 0 1]
True:[1 0 0 0 1 1 0 1]
True
77 + 64 = 141
------------
Error:[ 0.43420519]
Pred:[0 0 1 1 0 0 1 1]
True:[0 0 1 1 0 0 1 1]
True
26 + 25 = 51
------------
Error:[ 0.50252171]
Pred:[1 0 1 0 0 1 1 0]
True:[1 0 1 0 0 1 1 0]
True
80 + 86 = 166
------------
Error:[ 0.66292361]
Pred:[1 0 0 1 1 0 0 1]
True:[1 0 0 1 1 0 0 1]
True
42 + 111 = 153
------------
Error:[ 0.61874068]
Pred:[1 1 0 0 1 1 0 0]
True:[1 1 0 0 1 1 0 0]
True
117 + 87 = 204
------------
Error:[ 0.39324392]
Pred:[1 0 0 0 1 1 1 0]
True:[1 0 0 0 1 1 1 0]
True
65 + 77 = 142
------------
Error:[ 0.45513889]
Pred:[0 1 0 1 0 0 1 1]
True:[0 1 0 1 0 0 1 1]
True
70 + 13 = 83
------------
Error:[ 0.27921542]
Pred:[0 1 0 1 1 0 1 1]
True:[0 1 0 1 1 0 1 1]
True
90 + 1 = 91
------------
Error:[ 0.62198489]
Pred:[1 1 1 1 0 1 0 1]
True:[1 1 1 1 0 1 0 1]
True
127 + 118 = 245
------------
Error:[ 0.50229076]
Pred:[1 0 0 1 0 1 0 1]
True:[1 0 0 1 0 1 0 1]
True
57 + 92 = 149
------------
Error:[ 0.58217015]
Pred:[0 1 1 1 0 0 0 0]
True:[0 1 1 1 0 0 0 0]
True
21 + 91 = 112
------------
Error:[ 0.50710971]
Pred:[1 0 1 1 1 0 1 0]
True:[1 0 1 1 1 0 1 0]
True
115 + 71 = 186
------------
Error:[ 0.51719454]
Pred:[0 1 1 1 0 0 0 1]
True:[0 1 1 1 0 0 0 1]
True
18 + 95 = 113
------------
Error:[ 0.47046102]
Pred:[1 1 1 0 1 1 1 0]
True:[1 1 1 0 1 1 1 0]
True
116 + 122 = 238
------------
Error:[ 0.51538252]
Pred:[0 1 0 1 0 0 0 1]
True:[0 1 0 1 0 0 0 1]
True
47 + 34 = 81
------------
Error:[ 0.48485176]
Pred:[1 0 1 1 0 0 0 1]
True:[1 0 1 1 0 0 0 1]
True
73 + 104 = 177
------------
Error:[ 0.45145853]
Pred:[1 0 1 0 1 0 1 0]
True:[1 0 1 0 1 0 1 0]
True
77 + 93 = 170
------------
Error:[ 0.38580434]
Pred:[1 0 0 1 1 1 0 1]
True:[1 0 0 1 1 1 0 1]
True
116 + 41 = 157
------------
Error:[ 0.25024716]
Pred:[1 0 0 1 0 0 0 1]
True:[1 0 0 1 0 0 0 1]
True
80 + 65 = 145
------------
Error:[ 0.18765659]
Pred:[0 0 1 0 1 0 1 1]
True:[0 0 1 0 1 0 1 1]
True
2 + 41 = 43
------------
Error:[ 0.48164219]
Pred:[1 0 1 1 1 0 1 1]
True:[1 0 1 1 1 0 1 1]
True
124 + 63 = 187
------------
Error:[ 0.52533063]
Pred:[1 0 0 1 1 0 1 1]
True:[1 0 0 1 1 0 1 1]
True
52 + 103 = 155
------------
Error:[ 0.14392675]
Pred:[0 1 0 1 0 0 1 0]
True:[0 1 0 1 0 0 1 0]
True
81 + 1 = 82
------------
Error:[ 0.2536919]
Pred:[0 0 0 1 1 0 0 1]
True:[0 0 0 1 1 0 0 1]
True
20 + 5 = 25
------------
Error:[ 0.4577056]
Pred:[1 0 1 0 0 0 0 1]
True:[1 0 1 0 0 0 0 1]
True
90 + 71 = 161
------------
Error:[ 0.35454413]
Pred:[0 1 1 0 0 0 0 1]
True:[0 1 1 0 0 0 0 1]
True
72 + 25 = 97
------------
Error:[ 0.3027905]
Pred:[0 1 0 1 0 0 0 0]
True:[0 1 0 1 0 0 0 0]
True
11 + 69 = 80
------------
Error:[ 0.32028173]
Pred:[0 1 0 0 0 1 0 0]
True:[0 1 0 0 0 1 0 0]
True
49 + 19 = 68
------------
Error:[ 0.47741899]
Pred:[1 0 1 1 0 0 1 1]
True:[1 0 1 1 0 0 1 1]
True
70 + 109 = 179
------------
Error:[ 0.19099271]
Pred:[0 1 0 1 1 0 1 1]
True:[0 1 0 1 1 0 1 1]
True
18 + 73 = 91
------------
Error:[ 0.44434423]
Pred:[1 1 0 1 0 0 1 1]
True:[1 1 0 1 0 0 1 1]
True
111 + 100 = 211
------------
Error:[ 0.34579419]
Pred:[1 0 0 1 1 0 0 0]
True:[1 0 0 1 1 0 0 0]
True
73 + 79 = 152
------------
Error:[ 0.16166863]
Pred:[0 1 1 1 0 1 1 0]
True:[0 1 1 1 0 1 1 0]
True
113 + 5 = 118
------------
Error:[ 0.18488997]
Pred:[0 0 1 1 0 1 1 1]
True:[0 0 1 1 0 1 1 1]
True
2 + 53 = 55
------------
Error:[ 0.39762466]
Pred:[1 0 0 1 0 1 1 0]
True:[1 0 0 1 0 1 1 0]
True
127 + 23 = 150
------------
Error:[ 0.30393507]
Pred:[0 1 1 0 1 1 1 1]
True:[0 1 1 0 1 1 1 1]
True
97 + 14 = 111
------------
Error:[ 0.43896228]
Pred:[1 0 1 0 1 1 0 0]
True:[1 0 1 0 1 1 0 0]
True
78 + 94 = 172
------------
Error:[ 0.19477698]
Pred:[0 1 1 0 0 1 1 0]
True:[0 1 1 0 0 1 1 0]
True
36 + 66 = 102
------------
Error:[ 0.384797]
Pred:[1 0 1 1 0 1 0 0]
True:[1 0 1 1 0 1 0 0]
True
63 + 117 = 180
------------
Error:[ 0.24173037]
Pred:[0 1 1 1 1 1 1 1]
True:[0 1 1 1 1 1 1 1]
True
68 + 59 = 127
------------
Error:[ 0.40062057]
Pred:[0 1 0 1 0 1 0 1]
True:[0 1 0 1 0 1 0 1]
True
27 + 58 = 85
------------
Error:[ 0.4056001]
Pred:[0 1 0 0 0 1 0 1]
True:[0 1 0 0 0 1 0 1]
True
63 + 6 = 69
------------
Error:[ 0.41176447]
Pred:[1 0 0 1 1 1 1 1]
True:[1 0 0 1 1 1 1 1]
True
127 + 32 = 159
------------
Error:[ 0.45722834]
Pred:[1 1 0 0 1 0 0 0]
True:[1 1 0 0 1 0 0 0]
True
106 + 94 = 200
------------
Error:[ 0.26925754]
Pred:[0 1 1 0 1 1 1 0]
True:[0 1 1 0 1 1 1 0]
True
80 + 30 = 110
------------
Error:[ 0.18420419]
Pred:[0 0 0 1 1 1 0 0]
True:[0 0 0 1 1 1 0 0]
True
23 + 5 = 28
------------
Error:[ 0.42250863]
Pred:[1 1 0 1 0 0 1 0]
True:[1 1 0 1 0 0 1 0]
True
102 + 108 = 210
------------
Error:[ 0.15972739]
Pred:[0 0 1 1 1 0 1 0]
True:[0 0 1 1 1 0 1 0]
True
18 + 40 = 58
------------
Error:[ 0.36053754]
Pred:[1 0 0 1 0 1 0 0]
True:[1 0 0 1 0 1 0 0]
True
95 + 53 = 148
------------
Error:[ 0.12271426]
Pred:[0 0 0 0 1 1 0 1]
True:[0 0 0 0 1 1 0 1]
True
5 + 8 = 13
------------
Error:[ 0.29394636]
Pred:[0 1 1 0 0 1 1 0]
True:[0 1 1 0 0 1 1 0]
True
15 + 87 = 102
------------
Error:[ 0.43435847]
Pred:[1 0 1 0 0 0 0 0]
True:[1 0 1 0 0 0 0 0]
True
101 + 59 = 160
------------
Error:[ 0.25130702]
Pred:[0 0 1 0 0 1 0 0]
True:[0 0 1 0 0 1 0 0]
True
31 + 5 = 36
------------
Error:[ 0.37297311]
Pred:[0 1 0 1 0 0 1 0]
True:[0 1 0 1 0 0 1 0]
True
54 + 28 = 82
------------
Error:[ 0.23398697]
Pred:[0 0 0 1 1 0 1 0]
True:[0 0 0 1 1 0 1 0]
True
12 + 14 = 26
------------
Error:[ 0.33698497]
Pred:[0 1 1 1 1 0 1 0]
True:[0 1 1 1 1 0 1 0]
True
110 + 12 = 122
------------
Error:[ 0.32293045]
Pred:[0 1 0 1 0 0 0 0]
True:[0 1 0 1 0 0 0 0]
True
74 + 6 = 80
------------
Error:[ 0.29383664]
Pred:[0 1 0 0 1 0 0 1]
True:[0 1 0 0 1 0 0 1]
True
28 + 45 = 73
------------
Error:[ 0.29335243]
Pred:[1 0 0 1 1 0 1 0]
True:[1 0 0 1 1 0 1 0]
True
53 + 101 = 154
------------
Error:[ 0.37123225]
Pred:[1 1 1 0 1 1 0 0]
True:[1 1 1 0 1 1 0 0]
True
126 + 110 = 236
------------
Error:[ 0.26784038]
Pred:[1 1 0 1 1 1 1 0]
True:[1 1 0 1 1 1 1 0]
True
122 + 100 = 222
------------
Error:[ 0.34925971]
Pred:[0 0 1 1 0 1 1 1]
True:[0 0 1 1 0 1 1 1]
True
15 + 40 = 55
------------
Error:[ 0.2196544]
Pred:[0 1 1 0 1 1 1 0]
True:[0 1 1 0 1 1 1 0]
True
25 + 85 = 110
------------
Error:[ 0.14428875]
Pred:[0 1 0 0 1 0 1 0]
True:[0 1 0 0 1 0 1 0]
True
7 + 67 = 74
------------
Error:[ 0.30459088]
Pred:[0 1 0 0 0 0 0 0]
True:[0 1 0 0 0 0 0 0]
True
63 + 1 = 64
------------
Error:[ 0.1506828]
Pred:[0 1 1 1 1 0 1 1]
True:[0 1 1 1 1 0 1 1]
True
64 + 59 = 123
------------
Error:[ 0.27649987]
Pred:[0 1 0 1 1 1 0 0]
True:[0 1 0 1 1 1 0 0]
True
50 + 42 = 92
------------
Error:[ 0.27844325]
Pred:[1 1 0 0 0 0 1 0]
True:[1 1 0 0 0 0 1 0]
True
77 + 117 = 194
------------
Error:[ 0.29936468]
Pred:[0 1 0 1 0 0 0 0]
True:[0 1 0 1 0 0 0 0]
True
33 + 47 = 80
------------
Error:[ 0.17910203]
Pred:[1 0 0 1 1 0 1 1]
True:[1 0 0 1 1 0 1 1]
True
82 + 73 = 155
------------
Error:[ 0.19257083]
Pred:[1 0 1 1 1 1 0 1]
True:[1 0 1 1 1 1 0 1]
True
105 + 84 = 189
------------
Error:[ 0.12016148]
Pred:[0 1 0 0 1 1 1 0]
True:[0 1 0 0 1 1 1 0]
True
11 + 67 = 78
------------
Error:[ 0.26722366]
Pred:[1 0 1 0 1 1 1 0]
True:[1 0 1 0 1 1 1 0]
True
112 + 62 = 174
------------
Error:[ 0.28009711]
Pred:[1 0 1 0 0 1 0 1]
True:[1 0 1 0 0 1 0 1]
True
108 + 57 = 165
------------
Error:[ 0.22371447]
Pred:[0 1 1 1 0 0 1 0]
True:[0 1 1 1 0 0 1 0]
True
15 + 99 = 114
------------
Error:[ 0.27980071]
Pred:[1 0 1 0 0 1 1 1]
True:[1 0 1 0 0 1 1 1]
True
113 + 54 = 167
------------
Error:[ 0.22918159]
Pred:[0 0 0 1 1 1 1 1]
True:[0 0 0 1 1 1 1 1]
True
1 + 30 = 31
------------
Error:[ 0.31937504]
Pred:[1 0 1 1 0 0 1 1]
True:[1 0 1 1 0 0 1 1]
True
102 + 77 = 179
------------
Error:[ 0.22852688]
Pred:[1 1 1 0 0 1 1 0]
True:[1 1 1 0 0 1 1 0]
True
112 + 118 = 230
------------
Error:[ 0.19959135]
Pred:[1 1 0 1 0 0 1 0]
True:[1 1 0 1 0 0 1 0]
True
121 + 89 = 210
------------
Error:[ 0.12430546]
Pred:[0 1 0 0 1 1 0 0]
True:[0 1 0 0 1 1 0 0]
True
4 + 72 = 76
------------
Error:[ 0.22408818]
Pred:[0 1 0 0 0 1 1 0]
True:[0 1 0 0 0 1 1 0]
True
48 + 22 = 70
------------
Error:[ 0.20221729]
Pred:[1 0 1 0 0 1 1 0]
True:[1 0 1 0 0 1 1 0]
True
115 + 51 = 166
------------
Error:[ 0.17104793]
Pred:[1 0 0 0 1 0 1 0]
True:[1 0 0 0 1 0 1 0]
True
71 + 67 = 138
------------
Error:[ 0.29566632]
Pred:[1 0 0 0 0 1 1 1]
True:[1 0 0 0 0 1 1 1]
True
125 + 10 = 135
------------
Error:[ 0.26591579]
Pred:[1 0 1 0 1 0 0 1]
True:[1 0 1 0 1 0 0 1]
True
95 + 74 = 169
------------
Error:[ 0.23643373]
Pred:[0 0 1 0 0 1 1 0]
True:[0 0 1 0 0 1 1 0]
True
28 + 10 = 38
------------
Error:[ 0.29028257]
Pred:[1 0 1 1 0 0 1 0]
True:[1 0 1 1 0 0 1 0]
True
55 + 123 = 178
------------
Error:[ 0.23127843]
Pred:[1 1 0 0 1 1 0 1]
True:[1 1 0 0 1 1 0 1]
True
88 + 117 = 205
------------
Error:[ 0.21379392]
Pred:[1 0 0 0 0 1 1 1]
True:[1 0 0 0 0 1 1 1]
True
98 + 37 = 135
------------
Error:[ 0.28917566]
Pred:[1 0 1 0 1 0 0 1]
True:[1 0 1 0 1 0 0 1]
True
111 + 58 = 169
------------
Error:[ 0.25232534]
Pred:[1 1 0 1 1 1 0 0]
True:[1 1 0 1 1 1 0 0]
True
95 + 125 = 220
------------
Error:[ 0.23058793]
Pred:[1 0 1 0 1 0 1 0]
True:[1 0 1 0 1 0 1 0]
True
88 + 82 = 170
------------
Error:[ 0.30600955]
Pred:[1 0 0 0 0 0 0 1]
True:[1 0 0 0 0 0 0 1]
True
99 + 30 = 129
------------
Error:[ 0.24408132]
Pred:[0 1 0 1 1 0 0 0]
True:[0 1 0 1 1 0 0 0]
True
52 + 36 = 88
------------
Error:[ 0.26199122]
Pred:[1 0 0 1 1 0 1 0]
True:[1 0 0 1 1 0 1 0]
True
127 + 27 = 154
------------
Error:[ 0.2308597]
Pred:[1 1 0 1 1 1 0 0]
True:[1 1 0 1 1 1 0 0]
True
106 + 114 = 220
------------
Error:[ 0.26871153]
Pred:[1 0 0 0 1 0 0 0]
True:[1 0 0 0 1 0 0 0]
True
21 + 115 = 136
------------
Error:[ 0.14836389]
Pred:[0 1 1 1 0 1 1 1]
True:[0 1 1 1 0 1 1 1]
True
99 + 20 = 119
------------
Error:[ 0.26056678]
Pred:[1 1 1 1 0 1 1 1]
True:[1 1 1 1 0 1 1 1]
True
126 + 121 = 247
------------
Error:[ 0.21699123]
Pred:[0 1 0 0 0 1 1 1]
True:[0 1 0 0 0 1 1 1]
True
52 + 19 = 71
------------
Error:[ 0.15114712]
Pred:[0 1 1 1 0 0 1 1]
True:[0 1 1 1 0 0 1 1]
True
51 + 64 = 115
------------
Error:[ 0.23657825]
Pred:[1 0 0 1 1 1 1 1]
True:[1 0 0 1 1 1 1 1]
True
54 + 105 = 159
------------
Error:[ 0.21836464]
Pred:[1 1 1 1 1 0 1 0]
True:[1 1 1 1 1 0 1 0]
True
125 + 125 = 250
------------
Error:[ 0.11765047]
Pred:[0 1 1 0 1 0 1 1]
True:[0 1 1 0 1 0 1 1]
True
42 + 65 = 107
------------
Error:[ 0.22654752]
Pred:[0 1 1 1 1 1 1 1]
True:[0 1 1 1 1 1 1 1]
True
97 + 30 = 127
------------
Error:[ 0.09659835]
Pred:[0 1 1 1 0 1 0 1]
True:[0 1 1 1 0 1 0 1]
True
32 + 85 = 117
------------
Error:[ 0.20630273]
Pred:[1 1 1 1 0 0 1 1]
True:[1 1 1 1 0 0 1 1]
True
120 + 123 = 243
------------
Error:[ 0.19706751]
Pred:[0 1 1 1 0 1 0 0]
True:[0 1 1 1 0 1 0 0]
True
21 + 95 = 116
------------
Error:[ 0.22943276]
Pred:[1 0 1 0 0 0 1 0]
True:[1 0 1 0 0 0 1 0]
True
104 + 58 = 162
------------
Error:[ 0.23257017]
Pred:[0 1 1 0 0 0 1 1]
True:[0 1 1 0 0 0 1 1]
True
22 + 77 = 99
------------
Error:[ 0.15080149]
Pred:[0 1 0 1 1 1 1 0]
True:[0 1 0 1 1 1 1 0]
True
45 + 49 = 94
------------
Error:[ 0.12567795]
Pred:[0 0 1 1 0 1 1 1]
True:[0 0 1 1 0 1 1 1]
True
38 + 17 = 55
------------
Error:[ 0.28120981]
Pred:[1 1 0 0 0 1 1 0]
True:[1 1 0 0 0 1 1 0]
True
108 + 90 = 198
------------
Error:[ 0.32838608]
Pred:[1 0 0 0 1 0 0 0]
True:[1 0 0 0 1 0 0 0]
True
110 + 26 = 136
------------
Error:[ 0.22691487]
Pred:[1 1 0 0 0 1 1 1]
True:[1 1 0 0 0 1 1 1]
True
72 + 127 = 199
------------
Error:[ 0.09966136]
Pred:[0 0 1 0 0 1 0 0]
True:[0 0 1 0 0 1 0 0]
True
4 + 32 = 36
------------
Error:[ 0.23196088]
Pred:[0 1 0 1 0 1 1 1]
True:[0 1 0 1 0 1 1 1]
True
28 + 59 = 87
------------
Error:[ 0.14169074]
Pred:[0 0 0 1 1 0 1 0]
True:[0 0 0 1 1 0 1 0]
True
19 + 7 = 26
------------
Error:[ 0.11328021]
Pred:[0 1 1 0 1 1 0 0]
True:[0 1 1 0 1 1 0 0]
True
100 + 8 = 108
------------
In [8]:
#print(overallError_history)
x_range = range(max_iter//100)
plt.plot(x_range,overallError_history,'r-')
plt.ylabel('overallError')
plt.show()

plt.plot(x_range,accuracy_history,'b-')
plt.ylabel('accuracy')
plt.show()

Does it also work beyond the max binary dimension (8 bits)?

  • Let's check whether the RNN trained on 8-bit numbers also works for binary numbers with more digits, e.g., 10 bits or more!
In [9]:
# Scratch tests for the formatting code (kept commented out)
# ====================================================== #
#int('{0:09b}'.format(6))
#list(format(6, "08b"))
#results = list(map(int, list(format(6, "08b"))))
#results
#str(10)+"b"
# ====================================================== #

# Create binary representations longer than 8 bits
max_binary_dim = 10
largest_number = pow(2,max_binary_dim)
digit_key = "0"+str(max_binary_dim)+"b"
np.array(list(map(int,list(format(6, digit_key)))))
Out[9]:
array([0, 0, 0, 0, 0, 0, 0, 1, 1, 0])
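
Here digit_key evaluates to "010b", i.e. zero-padded 10-digit binary:

print(format(6, digit_key))  # '0000000110'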
In [10]:
# Reset the history buffers
overallError_history = list()
accuracy = list()
accuracy_history = list()
accuracy_count = 0

Everything is exactly the same except for the digit length being tested

  • This time, we only need the feed-forward step to obtain the prediction (a stripped-down sketch follows below).
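
Strictly speaking, the delta bookkeeping in the cell below is a leftover from the training loop; for pure inference a stripped-down step suffices. A minimal sketch, assuming a, b, and pred are set up as in the cell below:

# Inference-only loop (sketch): no deltas and no stored hidden-state history needed
h = np.zeros(hidden_dim)
for position in reversed(range(max_binary_dim)):
    X = np.array([[a[position], b[position]]])
    h = sigmoid(np.dot(X, synapse_0) + np.dot(h, synapse_h))
    pred[position] = np.round(sigmoid(np.dot(h, synapse_1))[0][0])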
In [11]:
max_iter = 10000
for j in range(max_iter):
    # Randomly pick two integers and convert them to their binary representations.
    a_int = np.random.randint(1,largest_number//2)
    a = np.array(list(map(int, list(format(a_int, digit_key)))))
    b_int = np.random.randint(1,largest_number//2)
    b = np.array(list(map(int, list(format(b_int, digit_key)))))
    # Compute the true answer and store it in binary form.
    c_int = a_int + b_int
    c = np.array(list(map(int, list(format(c_int, digit_key)))))
    
    # Buffer for the binary sum predicted by the RNN.
    pred = np.zeros_like(c)
    
    overallError = 0
    
    output_layer_deltas = list()
    hidden_layer_values = list()
    hidden_layer_values.append(np.zeros(hidden_dim)) # dim: (16,)

    # Feed forward!
    # Loop in reverse so that the least significant bit is processed first.
    for position in reversed(range(max_binary_dim)):
        
        # Fetch the input digits and the output label for this position
        X = np.array([[a[position],b[position]]]) # dim: (1, 2), e.g. [[1,0]]
        y = np.array([[c[position]]]) # dim: (1, 1), e.g. [[1]]
        
        # Compute the hidden layer: h_t = sigmoid(X*W_{hx} + h_{t-1}*W_{hh})
        hidden_layer = sigmoid(np.dot(X,synapse_0) + np.dot(hidden_layer_values[-1],synapse_h)) # dim: (1, 16)
        
        # Compute the output layer
        output_layer = sigmoid(np.dot(hidden_layer,synapse_1)) # dim: (1, 1), e.g. [[0.47174173]]
        
        # Compute the error
        output_layer_error = y-output_layer # dim: (1, 1) 
        
        # Accumulate for the error-curve display
        overallError += np.abs(output_layer_error[0]) # dim: (1, )          
        
        # Delta for backpropagation (unused here, since only the forward pass is needed)
        output_layer_deltas.append((output_layer_error) * sigmoid_output_to_derivative(output_layer))        
        
        # Store the prediction for the current position
        pred[position] = np.round(output_layer[0][0])
        
        # Store the hidden layer computed so far
        hidden_layer_values.append(copy.deepcopy(hidden_layer)) 
    
    if (j%100 == 0):
        overallError_history.append(overallError[0])
    
    
    # Compute accuracy
    check = np.equal(pred,c)
    if np.sum(check) == max_binary_dim:
        accuracy_count += 1
    if (j%100 == 0):
        accuracy_history.append(accuracy_count)
        accuracy_count = 0
    
    
    if (j % 100 == 0):
        print ("Error:" + str(overallError))
        print ("Pred:" + str(pred))  # 예측값
        print ("True:" + str(c))  # 실제값

        final_check = np.equal(pred,c)
        print (np.sum(final_check) == max_binary_dim)

        out = 0

        for index, x in enumerate(reversed(pred)):
            out += x * pow(2, index)
        print (str(a_int) + " + " + str(b_int) + " = " + str(out))
        print ("------------")
Error:[ 0.17447237]
Pred:[1 1 0 0 0 1 1 1 1 0]
True:[1 1 0 0 0 1 1 1 1 0]
True
384 + 414 = 798
------------
Error:[ 0.27651708]
Pred:[0 1 1 0 1 1 0 1 0 1]
True:[0 1 1 0 1 1 0 1 0 1]
True
246 + 191 = 437
------------
Error:[ 0.18742734]
Pred:[1 0 1 1 1 0 0 1 1 1]
True:[1 0 1 1 1 0 0 1 1 1]
True
419 + 324 = 743
------------
Error:[ 0.26819747]
Pred:[1 0 1 1 1 0 1 1 1 1]
True:[1 0 1 1 1 0 1 1 1 1]
True
312 + 439 = 751
------------
Error:[ 0.15437666]
Pred:[1 1 0 1 1 0 1 0 0 1]
True:[1 1 0 1 1 0 1 0 0 1]
True
416 + 457 = 873
------------
Error:[ 0.25902133]
Pred:[1 1 0 0 1 1 1 0 1 1]
True:[1 1 0 0 1 1 1 0 1 1]
True
333 + 494 = 827
------------
Error:[ 0.21948908]
Pred:[1 0 1 0 0 1 1 0 0 1]
True:[1 0 1 0 0 1 1 0 0 1]
True
312 + 353 = 665
------------
Error:[ 0.30909945]
Pred:[1 0 0 1 1 0 1 1 1 1]
True:[1 0 0 1 1 0 1 1 1 1]
True
503 + 120 = 623
------------
Error:[ 0.29611804]
Pred:[1 1 1 0 1 1 0 0 0 1]
True:[1 1 1 0 1 1 0 0 0 1]
True
486 + 459 = 945
------------
Error:[ 0.3206195]
Pred:[1 0 0 0 0 0 1 0 0 1]
True:[1 0 0 0 0 0 1 0 0 1]
True
397 + 124 = 521
------------
Error:[ 0.3040861]
Pred:[1 0 1 0 0 1 0 0 0 1]
True:[1 0 1 0 0 1 0 0 0 1]
True
278 + 379 = 657
------------
Error:[ 0.20313793]
Pred:[1 0 0 1 1 0 1 1 0 0]
True:[1 0 0 1 1 0 1 1 0 0]
True
171 + 449 = 620
------------
Error:[ 0.32864245]
Pred:[1 0 1 0 0 1 0 0 1 0]
True:[1 0 1 0 0 1 0 0 1 0]
True
499 + 159 = 658
------------
Error:[ 0.2906054]
Pred:[0 1 0 1 1 0 0 0 0 0]
True:[0 1 0 1 1 0 0 0 0 0]
True
211 + 141 = 352
------------
Error:[ 0.15756375]
Pred:[1 0 0 0 1 0 1 1 1 0]
True:[1 0 0 0 1 0 1 1 1 0]
True
292 + 266 = 558
------------
Error:[ 0.30643925]
Pred:[0 1 1 0 1 0 0 0 1 0]
True:[0 1 1 0 1 0 0 0 1 0]
True
316 + 102 = 418
------------
Error:[ 0.24416273]
Pred:[0 0 1 1 0 0 0 1 1 1]
True:[0 0 1 1 0 0 0 1 1 1]
True
125 + 74 = 199
------------
Error:[ 0.17996497]
Pred:[0 1 1 0 1 0 0 0 0 1]
True:[0 1 1 0 1 0 0 0 0 1]
True
240 + 177 = 417
------------
Error:[ 0.14176455]
Pred:[0 0 1 0 1 1 1 0 0 1]
True:[0 0 1 0 1 1 1 0 0 1]
True
72 + 113 = 185
------------
Error:[ 0.21626245]
Pred:[0 1 0 1 0 1 0 1 1 1]
True:[0 1 0 1 0 1 0 1 1 1]
True
98 + 245 = 343
------------
Error:[ 0.2611522]
Pred:[1 0 1 0 1 1 0 0 0 0]
True:[1 0 1 0 1 1 0 0 0 0]
True
440 + 248 = 688
------------
Error:[ 0.16428819]
Pred:[0 0 1 0 1 0 1 1 0 1]
True:[0 0 1 0 1 0 1 1 0 1]
True
10 + 163 = 173
------------
Error:[ 0.20849092]
Pred:[1 0 1 0 1 0 1 1 0 1]
True:[1 0 1 0 1 0 1 1 0 1]
True
266 + 419 = 685
------------
Error:[ 0.2461282]
Pred:[0 1 1 0 1 1 1 1 1 1]
True:[0 1 1 0 1 1 1 1 1 1]
True
345 + 102 = 447
------------
Error:[ 0.31476154]
Pred:[1 0 0 0 0 0 0 0 0 1]
True:[1 0 0 0 0 0 0 0 0 1]
True
181 + 332 = 513
------------
Error:[ 0.16605737]
Pred:[0 1 0 0 1 1 1 1 0 1]
True:[0 1 0 0 1 1 1 1 0 1]
True
172 + 145 = 317
------------
Error:[ 0.27297363]
Pred:[0 1 1 0 0 1 0 0 0 1]
True:[0 1 1 0 0 1 0 0 0 1]
True
221 + 180 = 401
------------
Error:[ 0.2958943]
Pred:[1 1 0 0 1 1 0 1 1 0]
True:[1 1 0 0 1 1 0 1 1 0]
True
447 + 375 = 822
------------
Error:[ 0.31495512]
Pred:[1 0 0 0 0 1 0 0 0 1]
True:[1 0 0 0 0 1 0 0 0 1]
True
43 + 486 = 529
------------
Error:[ 0.3297667]
Pred:[1 1 0 0 0 1 1 1 0 0]
True:[1 1 0 0 0 1 1 1 0 0]
True
477 + 319 = 796
------------
Error:[ 0.26117005]
Pred:[0 1 1 0 0 1 1 0 0 0]
True:[0 1 1 0 0 1 1 0 0 0]
True
359 + 49 = 408
------------
Error:[ 0.35877471]
Pred:[1 1 0 0 0 0 1 1 0 1]
True:[1 1 0 0 0 0 1 1 0 1]
True
318 + 463 = 781
------------
Error:[ 0.21790624]
Pred:[0 1 0 1 1 1 0 1 0 1]
True:[0 1 0 1 1 1 0 1 0 1]
True
263 + 110 = 373
------------
Error:[ 0.30378748]
Pred:[1 0 0 1 1 0 1 1 0 1]
True:[1 0 0 1 1 0 1 1 0 1]
True
430 + 191 = 621
------------
Error:[ 0.28304278]
Pred:[1 0 0 1 1 1 0 0 1 1]
True:[1 0 0 1 1 1 0 0 1 1]
True
202 + 425 = 627
------------
Error:[ 0.19708022]
Pred:[1 0 0 1 0 0 1 1 0 0]
True:[1 0 0 1 0 0 1 1 0 0]
True
262 + 326 = 588
------------
Error:[ 0.15685053]
Pred:[1 1 0 1 0 0 1 0 1 1]
True:[1 1 0 1 0 0 1 0 1 1]
True
384 + 459 = 843
------------
Error:[ 0.12434756]
Pred:[0 1 0 0 1 0 0 1 1 0]
True:[0 1 0 0 1 0 0 1 1 0]
True
256 + 38 = 294
------------
Error:[ 0.20826115]
Pred:[0 1 1 1 1 1 0 1 1 0]
True:[0 1 1 1 1 1 0 1 1 0]
True
191 + 311 = 502
------------
Error:[ 0.166506]
Pred:[0 1 1 1 1 1 1 1 0 0]
True:[0 1 1 1 1 1 1 1 0 0]
True
431 + 77 = 508
------------
Error:[ 0.18177219]
Pred:[0 1 0 1 0 0 1 0 1 0]
True:[0 1 0 1 0 0 1 0 1 0]
True
169 + 161 = 330
------------
Error:[ 0.29100458]
Pred:[0 0 1 1 0 1 1 0 0 0]
True:[0 0 1 1 0 1 1 0 0 0]
True
158 + 58 = 216
------------
Error:[ 0.10837928]
Pred:[0 0 0 0 1 1 1 0 0 1]
True:[0 0 0 0 1 1 1 0 0 1]
True
56 + 1 = 57
------------
Error:[ 0.28396873]
Pred:[0 1 1 0 0 1 0 1 1 1]
True:[0 1 1 0 0 1 0 1 1 1]
True
125 + 282 = 407
------------
Error:[ 0.1608341]
Pred:[1 1 0 0 0 1 0 1 0 1]
True:[1 1 0 0 0 1 0 1 0 1]
True
324 + 465 = 789
------------
Error:[ 0.28731671]
Pred:[1 0 0 0 0 1 1 1 1 1]
True:[1 0 0 0 0 1 1 1 1 1]
True
288 + 255 = 543
------------
Error:[ 0.15308156]
Pred:[0 1 0 1 0 0 0 0 1 1]
True:[0 1 0 1 0 0 0 0 1 1]
True
128 + 195 = 323
------------
Error:[ 0.15743634]
Pred:[0 0 0 1 0 1 1 1 0 1]
True:[0 0 0 1 0 1 1 1 0 1]
True
71 + 22 = 93
------------
Error:[ 0.21407874]
Pred:[1 0 0 0 0 0 1 1 1 0]
True:[1 0 0 0 0 0 1 1 1 0]
True
365 + 161 = 526
------------
Error:[ 0.2708268]
Pred:[0 1 0 1 0 0 0 1 0 1]
True:[0 1 0 1 0 0 0 1 0 1]
True
287 + 38 = 325
------------
Error:[ 0.22848946]
Pred:[1 0 0 1 1 0 0 1 0 0]
True:[1 0 0 1 1 0 0 1 0 0]
True
176 + 436 = 612
------------
Error:[ 0.23140022]
Pred:[1 1 0 1 0 1 0 1 1 0]
True:[1 1 0 1 0 1 0 1 1 0]
True
418 + 436 = 854
------------
Error:[ 0.27559053]
Pred:[1 1 1 0 0 1 1 0 1 0]
True:[1 1 1 0 0 1 1 0 1 0]
True
413 + 509 = 922
------------
Error:[ 0.20241596]
Pred:[0 0 1 1 0 0 0 0 1 1]
True:[0 0 1 1 0 0 0 0 1 1]
True
48 + 147 = 195
------------
Error:[ 0.16887661]
Pred:[0 0 1 1 0 1 1 1 1 1]
True:[0 0 1 1 0 1 1 1 1 1]
True
81 + 142 = 223
------------
Error:[ 0.26463917]
Pred:[0 1 1 1 1 0 0 0 1 1]
True:[0 1 1 1 1 0 0 0 1 1]
True
118 + 365 = 483
------------
Error:[ 0.28126299]
Pred:[1 0 1 1 0 0 0 0 0 1]
True:[1 0 1 1 0 0 0 0 0 1]
True
196 + 509 = 705
------------
Error:[ 0.24976575]
Pred:[0 1 1 1 0 1 1 1 1 1]
True:[0 1 1 1 0 1 1 1 1 1]
True
45 + 434 = 479
------------
Error:[ 0.2716244]
Pred:[0 1 0 0 0 1 0 0 1 1]
True:[0 1 0 0 0 1 0 0 1 1]
True
89 + 186 = 275
------------
Error:[ 0.21204717]
Pred:[0 1 0 1 1 1 0 0 0 1]
True:[0 1 0 1 1 1 0 0 0 1]
True
330 + 39 = 369
------------
Error:[ 0.1601862]
Pred:[0 0 1 1 0 1 1 1 0 0]
True:[0 0 1 1 0 1 1 1 0 0]
True
18 + 202 = 220
------------
Error:[ 0.11763846]
Pred:[0 1 1 1 1 1 1 1 0 0]
True:[0 1 1 1 1 1 1 1 0 0]
True
276 + 232 = 508
------------
Error:[ 0.3060972]
Pred:[0 1 1 0 0 0 0 1 1 0]
True:[0 1 1 0 0 0 0 1 1 0]
True
92 + 298 = 390
------------
Error:[ 0.28846954]
Pred:[0 1 1 0 1 1 1 0 1 1]
True:[0 1 1 0 1 1 1 0 1 1]
True
102 + 341 = 443
------------
Error:[ 0.36933459]
Pred:[1 1 1 1 0 0 0 1 0 0]
True:[1 1 1 1 0 0 0 1 0 0]
True
494 + 470 = 964
------------
Error:[ 0.31882406]
Pred:[1 0 0 1 0 1 0 0 0 0]
True:[1 0 0 1 0 1 0 0 0 0]
True
244 + 348 = 592
------------
Error:[ 0.19395465]
Pred:[0 0 0 0 1 1 0 0 0 0]
True:[0 0 0 0 1 1 0 0 0 0]
True
45 + 3 = 48
------------
Error:[ 0.16826027]
Pred:[0 1 0 1 0 1 1 0 1 1]
True:[0 1 0 1 0 1 1 0 1 1]
True
129 + 218 = 347
------------
Error:[ 0.28807155]
Pred:[0 1 0 0 1 0 1 0 1 1]
True:[0 1 0 0 1 0 1 0 1 1]
True
94 + 205 = 299
------------
Error:[ 0.14074882]
Pred:[0 0 0 1 0 0 1 1 1 1]
True:[0 0 0 1 0 0 1 1 1 1]
True
8 + 71 = 79
------------
Error:[ 0.33998782]
Pred:[1 0 0 0 1 1 1 1 0 1]
True:[1 0 0 0 1 1 1 1 0 1]
True
99 + 474 = 573
------------
Error:[ 0.20388869]
Pred:[0 1 1 1 1 1 1 0 0 0]
True:[0 1 1 1 1 1 1 0 0 0]
True
409 + 95 = 504
------------
Error:[ 0.27113412]
Pred:[0 1 1 0 0 0 0 1 1 1]
True:[0 1 1 0 0 0 0 1 1 1]
True
60 + 331 = 391
------------
Error:[ 0.27800925]
Pred:[1 0 0 0 0 1 0 0 0 1]
True:[1 0 0 0 0 1 0 0 0 1]
True
356 + 173 = 529
------------
Error:[ 0.22105296]
Pred:[1 1 1 0 0 0 0 1 0 1]
True:[1 1 1 0 0 0 0 1 0 1]
True
437 + 464 = 901
------------
Error:[ 0.2230257]
Pred:[1 1 1 0 1 0 0 0 0 1]
True:[1 1 1 0 1 0 0 0 0 1]
True
464 + 465 = 929
------------
Error:[ 0.18593703]
Pred:[0 1 0 1 0 1 1 1 1 1]
True:[0 1 0 1 0 1 1 1 1 1]
True
52 + 299 = 351
------------
Error:[ 0.27264573]
Pred:[0 1 1 1 0 1 0 0 0 0]
True:[0 1 1 1 0 1 0 0 0 0]
True
429 + 35 = 464
------------
Error:[ 0.2599256]
Pred:[1 0 0 1 0 1 1 1 1 1]
True:[1 0 0 1 0 1 1 1 1 1]
True
430 + 177 = 607
------------
Error:[ 0.20877375]
Pred:[1 0 1 0 1 0 0 1 1 0]
True:[1 0 1 0 1 0 0 1 1 0]
True
393 + 285 = 678
------------
Error:[ 0.32661152]
Pred:[1 0 0 0 0 0 0 1 1 1]
True:[1 0 0 0 0 0 0 1 1 1]
True
317 + 202 = 519
------------
Error:[ 0.28557791]
Pred:[1 0 1 0 0 1 0 0 0 1]
True:[1 0 1 0 0 1 0 0 0 1]
True
504 + 153 = 657
------------
Error:[ 0.21642076]
Pred:[1 0 0 0 1 1 0 1 0 1]
True:[1 0 0 0 1 1 0 1 0 1]
True
402 + 163 = 565
------------
Error:[ 0.20254958]
Pred:[0 1 1 1 0 1 0 1 1 0]
True:[0 1 1 1 0 1 0 1 1 0]
True
432 + 38 = 470
------------
Error:[ 0.15904447]
Pred:[1 0 0 1 1 1 1 0 1 1]
True:[1 0 0 1 1 1 1 0 1 1]
True
282 + 353 = 635
------------
Error:[ 0.20761984]
Pred:[1 0 1 0 1 1 0 1 0 1]
True:[1 0 1 0 1 1 0 1 0 1]
True
387 + 306 = 693
------------
Error:[ 0.29115035]
Pred:[1 0 0 1 1 0 0 1 1 0]
True:[1 0 0 1 1 0 0 1 1 0]
True
463 + 151 = 614
------------
Error:[ 0.27264868]
Pred:[1 1 0 0 0 1 0 0 0 0]
True:[1 1 0 0 0 1 0 0 0 0]
True
408 + 376 = 784
------------
Error:[ 0.30618218]
Pred:[1 0 1 0 1 1 0 1 0 1]
True:[1 0 1 0 1 1 0 1 0 1]
True
471 + 222 = 693
------------
Error:[ 0.29745122]
Pred:[0 1 1 1 0 0 0 1 1 0]
True:[0 1 1 1 0 0 0 1 1 0]
True
254 + 200 = 454
------------
Error:[ 0.21469789]
Pred:[0 1 0 1 1 1 1 1 1 1]
True:[0 1 0 1 1 1 1 1 1 1]
True
188 + 195 = 383
------------
Error:[ 0.28374133]
Pred:[1 0 1 0 1 0 0 1 0 1]
True:[1 0 1 0 1 0 0 1 0 1]
True
237 + 440 = 677
------------
Error:[ 0.28252835]
Pred:[0 1 0 0 0 1 0 1 0 1]
True:[0 1 0 0 0 1 0 1 0 1]
True
111 + 166 = 277
------------
Error:[ 0.09764052]
Pred:[0 1 1 0 1 1 0 1 1 0]
True:[0 1 1 0 1 1 0 1 1 0]
True
3 + 435 = 438
------------
Error:[ 0.20357224]
Pred:[0 1 1 0 1 1 0 0 1 1]
True:[0 1 1 0 1 1 0 0 1 1]
True
267 + 168 = 435
------------
Error:[ 0.18996958]
Pred:[1 1 0 1 1 1 1 0 1 0]
True:[1 1 0 1 1 1 1 0 1 0]
True
418 + 472 = 890
------------
Error:[ 0.15201967]
Pred:[0 0 1 1 1 0 1 1 1 1]
True:[0 0 1 1 1 0 1 1 1 1]
True
64 + 175 = 239
------------
Error:[ 0.16760019]
Pred:[1 0 1 0 0 0 0 0 1 1]
True:[1 0 1 0 0 0 0 0 1 1]
True
387 + 256 = 643
------------
Error:[ 0.25905665]
Pred:[0 1 1 0 0 0 0 1 1 1]
True:[0 1 1 0 0 0 0 1 1 1]
True
251 + 140 = 391
------------
Error:[ 0.24378146]
Pred:[1 0 0 0 1 0 0 1 1 0]
True:[1 0 0 0 1 0 0 1 1 0]
True
497 + 53 = 550
------------

Confirm that the error and accuracy are maintained

  • Check that the error stays low and the accuracy stays high on the longer, unseen digits.
In [12]:
#print(overallError_history)
x_range = range(max_iter//100)
plt.plot(x_range,overallError_history,'r-')
plt.ylabel('overallError')
plt.show()

plt.plot(x_range,accuracy_history,'b-')
plt.ylabel('accuracy')
plt.show()


