TensorFlow

TensorFlow Basics 25 - CNN on MNIST

코딩탕탕 2022. 12. 7. 12:03

# CNN on MNIST
# 1) Conv (extracts image features) + Pooling (subsamples the Conv output - shrinks the resulting feature map again)
# 2) Keep reducing the original image size, then run the FC layer stage (flatten the multi-dimensional Conv + Pooling output into a single 1-D row)
# 3) Pass the result to Dense layers for classification
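# (Added sketch, not in the original post) With padding='valid', a Conv or Pool
# layer shrinks each spatial side to (n - kernel) // stride + 1. Tracing this
# model: 28 -Conv3x3-> 26 -Pool2-> 13 -Conv3x3-> 11 -Pool2-> 5 -Conv3x3-> 3 -Pool2-> 1,
# which matches the Output Shape column of model.summary() below.
def valid_out(n, k, s):  # output side length for padding='valid'
    return (n - k) // s + 1
assert valid_out(28, 3, 1) == 26 and valid_out(26, 2, 2) == 13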

import tensorflow as tf
from keras import datasets, layers, models

(x_train, y_train), (x_test, y_test) = datasets.mnist.load_data()
print(x_train.shape, y_train.shape, x_test.shape, y_test.shape) # (60000, 28, 28) (60000,) (10000, 28, 28) (10000,)

# A CNN works with channels, so reshape the 3-D data to 4-D
x_train = x_train.reshape((60000, 28, 28, 1)) # grayscale images have a single channel
x_test = x_test.reshape((10000, 28, 28, 1)) # e.g. x_test[3, 12, 13, 0] addresses one pixel
# print(x_train.ndim)
# print(x_train[:1])
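# (Sketch) An equivalent, shape-agnostic way to add the channel axis, using
# numpy's ellipsis indexing:
# x_train = x_train[..., np.newaxis]   # (60000, 28, 28) -> (60000, 28, 28, 1)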

x_train = x_train / 255.0
x_test = x_test / 255.0
# print(x_train[:1])
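# (Sketch) The division above promotes the uint8 pixels to float64; Keras casts
# as needed, but an explicit float32 cast is the more common idiom:
# x_train = x_train.astype('float32') / 255.0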

# One-hot encoding of the labels is delegated to the model (hence sparse_categorical_crossentropy below)
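# (Sketch) The manual alternative: one-hot encode the labels yourself and
# switch the loss to match.
# from keras.utils import to_categorical
# y_train_oh = to_categorical(y_train, num_classes=10)   # shape (60000, 10)
# ...then compile with loss='categorical_crossentropy' instead of
# 'sparse_categorical_crossentropy'.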

# model
input_shape = (28, 28, 1)

# Using the Sequential API
model = models.Sequential()

model.add(layers.Conv2D(filters=16, kernel_size=(3, 3), strides=(1, 1), padding='valid',
                        activation='relu', input_shape=input_shape))
model.add(layers.MaxPool2D(pool_size=(2, 2)))
model.add(layers.Dropout(rate=0.2))

model.add(layers.Conv2D(filters=32, kernel_size=(3, 3), strides=(1, 1), padding='valid', activation='relu'))
model.add(layers.MaxPool2D(pool_size=(2, 2))) # filters plays a role similar to units in a Dense layer
model.add(layers.Dropout(rate=0.2))

model.add(layers.Conv2D(filters=64, kernel_size=(3, 3), strides=(1, 1), padding='valid', activation='relu'))
model.add(layers.MaxPool2D(pool_size=(2, 2))) # keyword names such as kernel_size can also be omitted (passed positionally)
model.add(layers.Dropout(rate=0.2))

model.add(layers.Flatten()) # for the FC layer (Fully Connected Layer): converts the multi-dimensional output to 1-D (lines every element up in a single row)

model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dropout(rate=0.2))
model.add(layers.Dense(32, activation='relu'))
model.add(layers.Dropout(rate=0.2))
model.add(layers.Dense(10, activation='softmax'))
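# (Sketch) The same stack can also be written with the functional API; the
# variable names below are illustrative:
# inputs = tf.keras.Input(shape=input_shape)
# net = layers.Conv2D(16, (3, 3), activation='relu')(inputs)
# net = layers.MaxPool2D((2, 2))(net)
# net = layers.Dropout(0.2)(net)
# ... (repeat for the 32- and 64-filter blocks, then Flatten and the Dense head)
# model = models.Model(inputs=inputs, outputs=net)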

print(model.summary())
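# (Added note) The Param # column of the summary can be verified by hand:
#   Conv2D: (kernel_h * kernel_w * in_channels + 1 bias) * filters
#     first conv : (3*3*1  + 1) * 16 = 160
#     second conv: (3*3*16 + 1) * 32 = 4640
#     third conv : (3*3*32 + 1) * 64 = 18496
#   Dense: (inputs + 1 bias) * units, e.g. (64 + 1) * 64 = 4160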

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

from keras.callbacks import EarlyStopping
es = EarlyStopping(monitor='val_loss', patience=3) # patience should be fairly generous
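# (Sketch) EarlyStopping can additionally roll the model back to the best
# epoch's weights once training stops:
# es = EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)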

history = model.fit(x_train, y_train, batch_size=128, epochs=1000, verbose=0, validation_split=0.2,
                    callbacks=[es])
# save the training history
import pickle
history = history.history
with open('cnn2_history.pickle', 'wb') as f:
    pickle.dump(history, f)

# evaluate the model
train_loss, train_acc = model.evaluate(x_train, y_train)
test_loss, test_acc = model.evaluate(x_test, y_test)
print('train_loss : {}, train_acc : {}'.format(train_loss, train_acc))
print('test_loss : {}, test_acc : {}'.format(test_loss, test_acc))

# save the model
model.save('cnn2_model.h5')
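# (Sketch) The '.h5' extension selects the single-file HDF5 format; a bare path
# would save in the TensorFlow SavedModel directory format instead:
# model.save('cnn2_model')   # SavedModel directory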

print()
# --- working with the trained model ---
mymodel = tf.keras.models.load_model('cnn2_model.h5')

# predict
import numpy as np
print('predicted :', np.argmax(mymodel.predict(x_test[:1])))
print('predicted :', np.argmax(mymodel.predict(x_test[[0]]))) # same meaning as the line above
print('actual :', y_test[0])
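# (Sketch) predict() returns the softmax output, so the full class-probability
# vector shows how confident the model is, not just which class wins:
# probs = mymodel.predict(x_test[:1])   # shape (1, 10), each row sums to ~1
# print(np.round(probs, 3))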

# visualization
import matplotlib.pyplot as plt

with open('cnn2_history.pickle', 'rb') as f:
    history = pickle.load(f)
    
def plot_acc_func(title=None):
    plt.plot(history['accuracy'], label='accuracy')
    plt.plot(history['val_accuracy'], label='val_accuracy')
    plt.title(title)
    plt.xlabel('epochs')
    plt.ylabel(title)
    plt.legend()
    
plot_acc_func('accuracy')
plt.show()

def plot_loss_func(title=None):
    plt.plot(history['loss'], label='loss')
    plt.plot(history['val_loss'], label='val_loss')
    plt.title(title)
    plt.xlabel('epochs')
    plt.ylabel(title)
    plt.legend()
    
plot_loss_func('loss')
plt.show()
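# (Sketch; its output is not part of the console log below) Show the first test
# image next to the model's prediction as a final sanity check:
# plt.imshow(x_test[0].reshape(28, 28), cmap='gray')
# plt.title('pred: {}, true: {}'.format(np.argmax(mymodel.predict(x_test[:1])), y_test[0]))
# plt.show()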


<console>
(60000, 28, 28) (60000,) (10000, 28, 28) (10000,)

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 conv2d (Conv2D)             (None, 26, 26, 16)        160       
                                                                 
 max_pooling2d (MaxPooling2D  (None, 13, 13, 16)       0         
 )                                                               
                                                                 
 dropout (Dropout)           (None, 13, 13, 16)        0         
                                                                 
 conv2d_1 (Conv2D)           (None, 11, 11, 32)        4640      
                                                                 
 max_pooling2d_1 (MaxPooling  (None, 5, 5, 32)         0         
 2D)                                                             
                                                                 
 dropout_1 (Dropout)         (None, 5, 5, 32)          0         
                                                                 
 conv2d_2 (Conv2D)           (None, 3, 3, 64)          18496     
                                                                 
 max_pooling2d_2 (MaxPooling  (None, 1, 1, 64)         0         
 2D)                                                             
                                                                 
 dropout_2 (Dropout)         (None, 1, 1, 64)          0         
                                                                 
 flatten (Flatten)           (None, 64)                0         
                                                                 
 dense (Dense)               (None, 64)                4160      
                                                                 
 dropout_3 (Dropout)         (None, 64)                0         
                                                                 
 dense_1 (Dense)             (None, 32)                2080      
                                                                 
 dropout_4 (Dropout)         (None, 32)                0         
                                                                 
 dense_2 (Dense)             (None, 10)                330       
                                                                 
=================================================================
Total params: 29,866
Trainable params: 29,866
Non-trainable params: 0
_________________________________________________________________
None

1875/1875 [==============================] - 5s 3ms/step - loss: 0.0301 - accuracy: 0.9913

313/313 [==============================] - 1s 3ms/step - loss: 0.0486 - accuracy: 0.9870
train_loss : 0.030139632523059845, train_acc : 0.9913166761398315
test_loss : 0.04862738028168678, test_acc : 0.9869999885559082


1/1 [==============================] - 0s 92ms/step
predicted : 7

1/1 [==============================] - 0s 13ms/step
predicted : 7
actual : 7


Because a CNN works with channels, the data had to be reshaped from 3-D to 4-D beforehand (e.g. x_test[3, 12, 13, 0]).

One-hot encoding of the labels was delegated to the model.


Accuracy visualization (plot of training vs. validation accuracy per epoch)


Loss visualization (plot of training vs. validation loss per epoch)