ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • TensorFlow 기초 20 - classification(이항분류 wine dataset), 학습 조기 종료, 모델 학습 시 모니터링 결과를 파일로 저장
    TensorFlow 2022. 12. 2. 17:00

     

     

    # Binary classification of the red & white wine dataset with Keras.
    #
    # Flow: load the CSV, split train/test, build a small dense network, train
    # with early stopping while checkpointing the best model to disk, plot the
    # learning curves, then reload the checkpoint and classify a few samples.
    
    from keras.models import Sequential, load_model
    from keras.layers import Dense
    from keras.callbacks import EarlyStopping, ModelCheckpoint
    import pandas as pd
    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn.model_selection import train_test_split
    
    # The CSV is read without a header row, so columns are addressed by position.
    wdf = pd.read_csv('../testdata/wine.csv', header=None)
    print(wdf.head(2))
    # Kept as print(...): the recorded console output shows the report followed
    # by the 'None' line that this wrapper produces.
    print(wdf.info())
    print(wdf.iloc[:, 12].unique()) # [1 0]
    print(len(wdf[wdf.iloc[:, 12] == 0])) # 4898
    print(len(wdf[wdf.iloc[:, 12] == 1])) # 1599
    
    dataset = wdf.values
    x = dataset[:, 0:12]  # columns 0..11 are used as features
    y = dataset[:, -1]    # last column (index 12) holds the 0/1 label
    print(x[:1])
    print(y[:1])
    
    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.3, random_state=12)
    
    # model: three ReLU hidden layers and one sigmoid unit for the binary output
    model = Sequential()
    model.add(Dense(units=32, input_dim=12, activation='relu'))
    model.add(Dense(units=16, activation='relu'))
    model.add(Dense(units=8, activation='relu'))
    model.add(Dense(units=1, activation='sigmoid'))
    # summary() prints the table itself and returns None; wrapping it in print()
    # would only add a stray 'None' line to the output.
    model.summary()
    
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    
    # Evaluate before any training (baseline with the initial weights)
    loss, acc = model.evaluate(x_train, y_train, verbose=2)
    print('loss, acc :', loss, acc)
    
    # Early stopping: watch val_loss with a patience of 5 epochs
    early_stop = EarlyStopping(monitor='val_loss', patience=5)
    
    # Save the monitored best model to a file during training
    # (if there is no validation data, monitor 'loss' instead).
    chkpoint = ModelCheckpoint(filepath='cl3model.hdf5', monitor='val_loss', verbose=0, save_best_only=True)
    
    # validation_split supplies the val_loss / val_accuracy that both callbacks monitor.
    history = model.fit(x_train, y_train, epochs=1000, batch_size=32, verbose=2,
                        validation_split=0.2,
                        callbacks=[early_stop, chkpoint])
    
    # Evaluate after training. This uses the in-memory model as fit() left it;
    # the checkpointed best model is loaded separately further down.
    loss, acc = model.evaluate(x_test, y_test, verbose=2)
    print('loss, acc :', loss, acc)
    
    # Per-epoch curves. They get their own names so the scalar test metrics
    # `loss` / `acc` above are not overwritten by lists.
    val_loss_hist = history.history['val_loss']
    train_loss_hist = history.history['loss']
    print('vloss :', val_loss_hist)
    print('loss :', train_loss_hist)
    
    val_acc_hist = history.history['val_accuracy']
    train_acc_hist = history.history['accuracy']
    print('vaccuracy :', val_acc_hist)
    print('accuracy :', train_acc_hist)
    
    # Visualization: x axis is the epoch index actually run (early stopping may
    # end training well before epochs=1000).
    epoch_len = np.arange(len(train_acc_hist))
    
    plt.plot(epoch_len, val_loss_hist, c='red', label='val_loss')
    plt.plot(epoch_len, train_loss_hist, c='blue', label='loss')
    plt.xlabel('epochs')
    plt.ylabel('loss')
    plt.legend(loc='best')
    plt.show()
    
    plt.plot(epoch_len, val_acc_hist, c='red', label='vaccuracy')
    plt.plot(epoch_len, train_acc_hist, c='blue', label='accuracy')
    plt.xlabel('epochs')
    plt.ylabel('accuracy')
    plt.legend(loc='best')
    plt.show()
    
    print()
    # Load the best model saved by the checkpoint and classify new data
    mymodel = load_model('cl3model.hdf5')
    new_data = x_test[:5, :]
    print(new_data)
    pred = mymodel.predict(new_data)
    # The sigmoid output is thresholded at 0.5 to obtain the 0/1 class.
    print('pred :', np.where(pred > 0.5, 1, 0).flatten())
    
    
    <console>
        0     1    2    3      4     5     6       7     8     9    10  11  12
    0  7.4  0.70  0.0  1.9  0.076  11.0  34.0  0.9978  3.51  0.56  9.4   5   1
    1  7.8  0.88  0.0  2.6  0.098  25.0  67.0  0.9968  3.20  0.68  9.8   5   1
    <class 'pandas.core.frame.DataFrame'>
    RangeIndex: 6497 entries, 0 to 6496
    Data columns (total 13 columns):
     #   Column  Non-Null Count  Dtype  
    ---  ------  --------------  -----  
     0   0       6497 non-null   float64
     1   1       6497 non-null   float64
     2   2       6497 non-null   float64
     3   3       6497 non-null   float64
     4   4       6497 non-null   float64
     5   5       6497 non-null   float64
     6   6       6497 non-null   float64
     7   7       6497 non-null   float64
     8   8       6497 non-null   float64
     9   9       6497 non-null   float64
     10  10      6497 non-null   float64
     11  11      6497 non-null   int64  
     12  12      6497 non-null   int64  
    dtypes: float64(11), int64(2)
    memory usage: 660.0 KB
    None
    [1 0]
    4898
    1599
    [[ 7.4     0.7     0.      1.9     0.076  11.     34.      0.9978  3.51
       0.56    9.4     5.    ]]
    [1.]
    
    Model: "sequential"
    _________________________________________________________________
     Layer (type)                Output Shape              Param #   
    =================================================================
     dense (Dense)               (None, 32)                416       
                                                                     
     dense_1 (Dense)             (None, 16)                528       
                                                                     
     dense_2 (Dense)             (None, 8)                 136       
                                                                     
     dense_3 (Dense)             (None, 1)                 9         
                                                                     
    =================================================================
    Total params: 1,089
    Trainable params: 1,089
    Non-trainable params: 0
    
    loss, acc : 0.0589471161365509 0.9856410026550293
    vloss : [0.20949411392211914, 0.18525467813014984, 0.17078757286071777, 0.16653349995613098, 0.14237532019615173, 0.12865032255649567, 0.14395548403263092, 0.11100532114505768, 0.12268267571926117, 0.1050923764705658, 0.09925971180200577, 0.09010053426027298, 0.08741936832666397, 0.11127576977014542, 0.07937368005514145, 0.07300767302513123, 0.10087704658508301, 0.07097449898719788, 0.06528158485889435, 0.0651187151670456, 0.07779967784881592, 0.06177672743797302, 0.05944723263382912, 0.0974210798740387, 0.07251864671707153, 0.06752219051122665, 0.0567992702126503, 0.07402747124433517, 0.05377015843987465, 0.052300646901130676, 0.059112224727869034, 0.051389697939157486, 0.08173388987779617, 0.07011207938194275, 0.05566557124257088, 0.0589790940284729, 0.049351297318935394, 0.051809050142765045, 0.05481600761413574, 0.05469350889325142, 0.06485271453857422, 0.05285915732383728]
    loss : [0.27672719955444336, 0.21581678092479706, 0.1982879340648651, 0.1879514753818512, 0.17713402211666107, 0.16040614247322083, 0.153812974691391, 0.14045995473861694, 0.13141627609729767, 0.11534695327281952, 0.11378927528858185, 0.10232722759246826, 0.11021049320697784, 0.10000675916671753, 0.08864924311637878, 0.08983210474252701, 0.08327782899141312, 0.08523685485124588, 0.0762864425778389, 0.07394008338451385, 0.08104557543992996, 0.07244405150413513, 0.0693037137389183, 0.06893828511238098, 0.07259487360715866, 0.0741797462105751, 0.06516513973474503, 0.06913289427757263, 0.06532575190067291, 0.06511712074279785, 0.06924497336149216, 0.0667584240436554, 0.06360393017530441, 0.0716901645064354, 0.07001704722642899, 0.06621268391609192, 0.05863720923662186, 0.06977786868810654, 0.062266986817121506, 0.054250918328762054, 0.05558239668607712, 0.06645198911428452]
    vaccuracy : [0.9219779968261719, 0.9384615421295166, 0.9351648092269897, 0.9406593441963196, 0.9516483545303345, 0.9549450278282166, 0.9439560174942017, 0.9560439586639404, 0.9703296422958374, 0.9549450278282166, 0.9560439586639404, 0.9725274443626404, 0.9626373648643494, 0.9714285731315613, 0.9703296422958374, 0.9758241772651672, 0.9615384340286255, 0.9802197813987732, 0.9780219793319702, 0.9769230484962463, 0.9791208505630493, 0.9780219793319702, 0.9813186526298523, 0.9659340381622314, 0.9725274443626404, 0.9747252464294434, 0.9813186526298523, 0.9714285731315613, 0.9802197813987732, 0.9813186526298523, 0.9824175834655762, 0.9835164546966553, 0.9714285731315613, 0.9758241772651672, 0.9835164546966553, 0.9835164546966553, 0.9835164546966553, 0.9846153855323792, 0.9813186526298523, 0.9824175834655762, 0.9791208505630493, 0.9802197813987732]
    accuracy : [0.8897442817687988, 0.9246631860733032, 0.9274126887321472, 0.9320868849754333, 0.9373109936714172, 0.9408853650093079, 0.9466593265533447, 0.9518834352493286, 0.9521583914756775, 0.9593071341514587, 0.9620566368103027, 0.964256227016449, 0.9620566368103027, 0.9659059643745422, 0.971405029296875, 0.9727797508239746, 0.9733296632766724, 0.9722298383712769, 0.9752543568611145, 0.9777289032936096, 0.9725047945976257, 0.9771789908409119, 0.9788287281990051, 0.979103684425354, 0.9774539470672607, 0.9749794006347656, 0.979103684425354, 0.9780038595199585, 0.981853187084198, 0.979653537273407, 0.9785537719726562, 0.9810283184051514, 0.979103684425354, 0.9782788157463074, 0.9788287281990051, 0.9815782308578491, 0.9821281433105469, 0.9785537719726562, 0.9835028648376465, 0.9835028648376465, 0.984602689743042, 0.9780038595199585]
    
    pred : [0 0 0 0 1]

    이항분류의 경우에는 compile 시 loss='binary_crossentropy', metrics=['accuracy']를 주고, 출력층에는 units=1, activation='sigmoid'를 준다.

     

     

     

    loss 시각화

     

    accuracy 시각화

    댓글

Designed by Tistory.