  • TensorFlow Basics 25 - CNN on MNIST
    TensorFlow 2022. 12. 7. 12:03

     

     

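    # CNN on MNIST
    # 1) Conv (extracts image features) + Pooling (downsamples the Conv output, shrinking the feature maps again)
    # 2) Once the image has been shrunk this way, flatten the multi-dimensional Conv + Pooling output into a 1-D vector for the fully connected (FC) layers
    # 3) Pass that vector to Dense layers, which perform the classification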
    
    import tensorflow as tf
    from tensorflow.keras import datasets, layers, models
    
    (x_train, y_train), (x_test, y_test) = datasets.mnist.load_data()
    print(x_train.shape, y_train.shape, x_test.shape, y_test.shape) # (60000, 28, 28) (60000,) (10000, 28, 28) (10000,)
    
    # CNNs expect a channel axis, so reshape the 3-D data (samples, height, width) to 4-D
    x_train = x_train.reshape((60000, 28, 28, 1)) # grayscale images have a single channel
    x_test = x_test.reshape((10000, 28, 28, 1)) # e.g. x_test[3, 12, 13, 0]
    # print(x_train.ndim)
    # print(x_train[:1])
    
    x_train = x_train / 255.0
    x_test = x_test / 255.0
    # print(x_train[:1])
    
    # No one-hot encoding of the labels: the sparse_categorical_crossentropy loss takes integer labels directly
    
    # model
    input_shape = (28, 28, 1)
    
    # Build the model with the Sequential API
    model = models.Sequential()
    
    model.add(layers.Conv2D(filters=16, kernel_size=(3, 3), strides=(1, 1), padding='valid',
                            activation='relu', input_shape=input_shape))
    model.add(layers.MaxPool2D(pool_size=(2, 2)))
    model.add(layers.Dropout(rate=0.2))
    
    model.add(layers.Conv2D(filters=32, kernel_size=(3, 3), strides=(1, 1), padding='valid', activation='relu')) # filters is similar to units in Dense: each filter yields one output channel
    model.add(layers.MaxPool2D(pool_size=(2, 2)))
    model.add(layers.Dropout(rate=0.2))
    
    model.add(layers.Conv2D(filters=64, kernel_size=(3, 3), strides=(1, 1), padding='valid', activation='relu')) # the kernel_size= keyword can be omitted, e.g. Conv2D(64, (3, 3))
    model.add(layers.MaxPool2D(pool_size=(2, 2)))
    model.add(layers.Dropout(rate=0.2))
    
    model.add(layers.Flatten()) # Flatten the multi-dimensional feature maps into a 1-D vector for the fully connected (FC) Dense layers
    
    model.add(layers.Dense(64, activation='relu'))
    model.add(layers.Dropout(rate=0.2))
    model.add(layers.Dense(32, activation='relu'))
    model.add(layers.Dropout(rate=0.2))
    model.add(layers.Dense(10, activation='softmax'))
    
    print(model.summary())
    
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    
    from tensorflow.keras.callbacks import EarlyStopping
    es = EarlyStopping(monitor='val_loss', patience=3) # set patience generously; a small value can stop training on a brief plateau
    
    history = model.fit(x_train, y_train, batch_size=128, epochs=1000, verbose=0, validation_split=0.2,
                        callbacks=[es])
    # Save the training history
    import pickle
    history = history.history
    with open('cnn2_history.pickle', 'wb') as f:
        pickle.dump(history, f)
    
    # Evaluate the model
    train_loss, train_acc = model.evaluate(x_train, y_train)
    test_loss, test_acc = model.evaluate(x_test, y_test)
    print('train_loss : {}, train_acc : {}'.format(train_loss, train_acc))
    print('test_loss : {}, test_acc : {}'.format(test_loss, test_acc))
    
    # Save the model
    model.save('cnn2_model.h5')
    
    print()
    # --- Work with the trained model ---
    mymodel = tf.keras.models.load_model('cnn2_model.h5')
    
    # predict
    import numpy as np
    print('Predicted :', np.argmax(mymodel.predict(x_test[:1])))
    print('Predicted :', np.argmax(mymodel.predict(x_test[[0]]))) # same as above: both keep the batch axis, shape (1, 28, 28, 1)
    print('Actual :', y_test[0])
    
    # Visualization
    import matplotlib.pyplot as plt
    
    with open('cnn2_history.pickle', 'rb') as f:
        history = pickle.load(f)
        
    def plot_acc_func(title=None):
        plt.plot(history['accuracy'], label='accuracy')
        plt.plot(history['val_accuracy'], label='val_accuracy')
        plt.title(title)
        plt.xlabel('epochs')
        plt.ylabel(title)
        plt.legend()
        
    plot_acc_func('accuracy')
    plt.show()
    
    def plot_loss_func(title=None):
        plt.plot(history['loss'], label='loss')
        plt.plot(history['val_loss'], label='val_loss')
        plt.title(title)
        plt.xlabel('epochs')
        plt.ylabel(title)
        plt.legend()
        
    plot_loss_func('loss')
    plt.show()
    
    
    <console>
    (60000, 28, 28) (60000,) (10000, 28, 28) (10000,)
    
    Model: "sequential"
    _________________________________________________________________
     Layer (type)                Output Shape              Param #   
    =================================================================
     conv2d (Conv2D)             (None, 26, 26, 16)        160       
                                                                     
     max_pooling2d (MaxPooling2D  (None, 13, 13, 16)       0         
     )                                                               
                                                                     
     dropout (Dropout)           (None, 13, 13, 16)        0         
                                                                     
     conv2d_1 (Conv2D)           (None, 11, 11, 32)        4640      
                                                                     
     max_pooling2d_1 (MaxPooling  (None, 5, 5, 32)         0         
     2D)                                                             
                                                                     
     dropout_1 (Dropout)         (None, 5, 5, 32)          0         
                                                                     
     conv2d_2 (Conv2D)           (None, 3, 3, 64)          18496     
                                                                     
     max_pooling2d_2 (MaxPooling  (None, 1, 1, 64)         0         
     2D)                                                             
                                                                     
     dropout_2 (Dropout)         (None, 1, 1, 64)          0         
                                                                     
     flatten (Flatten)           (None, 64)                0         
                                                                     
     dense (Dense)               (None, 64)                4160      
                                                                     
     dropout_3 (Dropout)         (None, 64)                0         
                                                                     
     dense_1 (Dense)             (None, 32)                2080      
                                                                     
     dropout_4 (Dropout)         (None, 32)                0         
                                                                     
     dense_2 (Dense)             (None, 10)                330       
                                                                     
    =================================================================
    Total params: 29,866
    Trainable params: 29,866
    Non-trainable params: 0
    _________________________________________________________________
    None
    
    1875/1875 [==============================] - 5s 3ms/step - loss: 0.0301 - accuracy: 0.9913
    
    313/313 [==============================] - 1s 3ms/step - loss: 0.0486 - accuracy: 0.9870
    train_loss : 0.030139632523059845, train_acc : 0.9913166761398315
    test_loss : 0.04862738028168678, test_acc : 0.9869999885559082
    
    
    1/1 [==============================] - 0s 92ms/step
    Predicted : 7
    
    1/1 [==============================] - 0s 13ms/step
    Predicted : 7
    Actual : 7
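The output shapes and parameter counts in the model summary above follow from simple arithmetic. The sketch below recomputes them in plain Python, assuming the same layer stack (3x3 Conv2D with stride 1 and padding='valid', 2x2 MaxPool2D, then Dense 64/32/10). The helper names are hypothetical, and Dropout is skipped because it has no parameters and does not change the shape.

```python
# Output size of a 'valid' (no padding) convolution or pooling window:
#   out = (in - window) // stride + 1
def valid_out(size, window, stride):
    return (size - window) // stride + 1

# Conv2D parameters: one (kh x kw x in_channels) kernel plus one bias per filter
def conv2d_params(kernel_h, kernel_w, in_channels, filters):
    return (kernel_h * kernel_w * in_channels + 1) * filters

# Dense parameters: weight matrix plus bias vector
def dense_params(in_units, units):
    return (in_units + 1) * units

def trace_model(size=28, in_channels=1, filters=(16, 32, 64), dense_units=(64, 32, 10)):
    """Walk the Conv2D -> MaxPool2D blocks, Flatten, then the Dense stack."""
    rows = []
    channels = in_channels
    total = 0
    for f in filters:
        size = valid_out(size, 3, 1)      # Conv2D 3x3, stride 1, padding='valid'
        p = conv2d_params(3, 3, channels, f)
        rows.append(('Conv2D', (size, size, f), p))
        total += p
        size = valid_out(size, 2, 2)      # MaxPool2D 2x2 (stride defaults to pool_size)
        rows.append(('MaxPool2D', (size, size, f), 0))
        channels = f
    units = size * size * channels        # Flatten
    rows.append(('Flatten', (units,), 0))
    for u in dense_units:
        p = dense_params(units, u)
        rows.append(('Dense', (u,), p))
        total += p
        units = u
    return rows, total

rows, total = trace_model()
for name, shape, params in rows:
    print(f'{name:<10} {str(shape):<14} {params}')
print('Total params:', total)
```

The total comes out to 29,866, matching the Total params line of model.summary(). For example, the first Conv2D has (3 x 3 x 1 + 1) x 16 = 160 parameters, and the last MaxPool2D reduces 3x3 to 1x1, which is why Flatten yields only 64 values.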

     

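    Because a CNN works with channels, the 3-D data (samples, height, width) must first be reshaped into 4-D (samples, height, width, channels). For grayscale MNIST the channel axis has size 1, so a single pixel is addressed as, for example, x_test[3, 12, 13, 0].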
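The reshape and the indexing rule can be checked on a small NumPy array. This sketch uses random data as a stand-in for MNIST (an assumption, to keep it self-contained) and also shows why the prediction code passes x_test[:1] or x_test[[0]] rather than x_test[0]: only the first two keep the batch axis.

```python
import numpy as np

# Stand-in for MNIST: 100 grayscale 28x28 images (random values, not real data)
images = np.random.randint(0, 256, size=(100, 28, 28), dtype=np.uint8)

# Add a trailing channel axis: (samples, height, width) -> (samples, height, width, channels)
x = images.reshape((-1, 28, 28, 1))   # -1 lets NumPy infer the sample count
x_alt = images[..., np.newaxis]       # equivalent to np.expand_dims(images, -1)

# A grayscale image has one channel, so the only valid channel index is 0
pixel = x[3, 12, 13, 0]

# Slicing or list-indexing keeps the batch axis; plain integer indexing drops it
one_by_slice = x[:1]    # shape (1, 28, 28, 1)
one_by_list = x[[0]]    # shape (1, 28, 28, 1), same data as x[:1]
single = x[0]           # shape (28, 28, 1), no batch axis

print(x.shape, one_by_slice.shape, one_by_list.shape, single.shape)
```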

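    The labels were not one-hot encoded; that was left to the model, since loss='sparse_categorical_crossentropy' accepts the integer labels directly.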
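Skipping one-hot encoding works because sparse categorical cross-entropy is the same quantity as categorical cross-entropy on one-hot labels: it simply reads the true-class probability by index. The NumPy sketch below checks that on made-up softmax outputs. It restates the formula rather than calling Keras (Keras additionally clips probabilities away from 0 before the log), and the function names are hypothetical.

```python
import numpy as np

def categorical_crossentropy(one_hot, probs):
    # Mean over samples of -sum(y * log(p)) across classes
    return float(np.mean(-np.sum(one_hot * np.log(probs), axis=1)))

def sparse_categorical_crossentropy(labels, probs):
    # Same loss, but picks the true-class probability by integer index
    return float(np.mean(-np.log(probs[np.arange(len(labels)), labels])))

# Made-up softmax outputs for 3 samples over 10 classes (each row sums to 1)
rng = np.random.default_rng(0)
logits = rng.normal(size=(3, 10))
probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)

labels = np.array([7, 2, 1])     # integer labels, like y_train
one_hot = np.eye(10)[labels]     # what one-hot encoding would produce

sparse_loss = sparse_categorical_crossentropy(labels, probs)
dense_loss = categorical_crossentropy(one_hot, probs)
print(sparse_loss, dense_loss)
```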

     

    Accuracy visualization

     

    Loss visualization
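The val_loss curve in the loss plot is what EarlyStopping(monitor='val_loss', patience=3) watches. Below is a simplified plain-Python sketch of the stopping rule, assuming the default min_delta=0; it approximates the callback's behavior and is not the Keras source. The function name and the loss values are hypothetical.

```python
def epochs_run(val_losses, patience=3):
    """Return how many epochs run before a simplified early-stopping rule fires."""
    best = float('inf')
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:          # improvement: remember it and reset the counter
            best = loss
            wait = 0
        else:                    # no improvement this epoch
            wait += 1
            if wait >= patience:
                return epoch + 1
    return len(val_losses)

# Hypothetical val_loss history: a 3-epoch plateau, then a further drop
val_loss_series = [0.30, 0.20, 0.15, 0.16, 0.17, 0.18, 0.10]
print(epochs_run(val_loss_series, patience=3))  # 6: stops before reaching 0.10
print(epochs_run(val_loss_series, patience=4))  # 7: survives the plateau
```

This is why a generous patience helps. Note also that, by default (restore_best_weights=False), EarlyStopping leaves the model with the weights from the last epoch it ran, not the best one.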
