TensorFlow

TensorFlow 기초 20 - classification(이항분류 wine dataset), 학습 조기 종료, 모델 학습 시 모니터링 결과를 파일로 저장

코딩탕탕 2022. 12. 2. 17:00

 

 

# red&white wine dataset으로 분류 모델 작성

from keras.models import Sequential
from keras.layers import Dense
from keras.callbacks import EarlyStopping, ModelCheckpoint
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

wdf = pd.read_csv('../testdata/wine.csv', header=None)
print(wdf.head(2))
# DataFrame.info() prints its own report and returns None, so call it
# directly instead of wrapping it in print() (which emitted a stray "None").
wdf.info()
print(wdf.iloc[:, 12].unique())  # [1 0]
print(len(wdf[wdf.iloc[:, 12] == 0]))  # 4898
print(len(wdf[wdf.iloc[:, 12] == 1]))  # 1599

# Columns 0-11 are the features; column 12 (the last one) is the 0/1 label.
dataset = wdf.values
x = dataset[:, 0:12]
y = dataset[:, -1]
print(x[:1])
print(y[:1])

# The classes are imbalanced (4898 vs 1599), so stratify the split to keep
# the same label ratio in both the train and test sets.
x_train, x_test, y_train, y_test = train_test_split(
    x, y, test_size=0.3, random_state=12, stratify=y)

# Feature magnitudes differ widely (e.g. column 7 is ~0.99 while column 6 is
# ~34 in the first row), which slows and destabilises gradient descent.
# Standardise with statistics learned from the training split only, so no
# information from the test set leaks into training.
scaler = StandardScaler()
x_train = scaler.fit_transform(x_train)
x_test = scaler.transform(x_test)

# model: 12 inputs -> 32 -> 16 -> 8 ReLU hidden units -> 1 sigmoid output
# (binary classification).
model = Sequential()
model.add(Dense(units=32, input_dim=12, activation='relu'))
model.add(Dense(units=16, activation='relu'))
model.add(Dense(units=8, activation='relu'))
model.add(Dense(units=1, activation='sigmoid'))
model.summary()  # prints the table itself and returns None

model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Baseline: evaluate the untrained model on the test set so the numbers are
# directly comparable with the post-training evaluation below.
loss, acc = model.evaluate(x_test, y_test, verbose=2)
print('loss, acc :', loss, acc)

# Early stopping: stop once val_loss has not improved for 5 epochs, and roll
# the in-memory model back to its best epoch. Without restore_best_weights the
# model evaluated after fit() would be the last (worse) epoch, not the one the
# checkpoint saved.
early_stop = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

# Save the model to a file whenever the monitored value improves
# (if there is no validation data, monitor 'loss' instead).
# NOTE(review): Keras 3 only accepts a '.keras' suffix here (or '.weights.h5'
# with save_weights_only=True); '.hdf5' works on Keras 2.x - confirm against
# the installed version.
chkpoint = ModelCheckpoint(filepath='cl3model.hdf5', monitor='val_loss', verbose=0, save_best_only=True)

history = model.fit(x_train, y_train, epochs=1000, batch_size=32, verbose=2,
                    validation_split=0.2,
                    callbacks=[early_stop, chkpoint])

# Evaluate the trained model on the held-out test set.
loss, acc = model.evaluate(x_test, y_test, verbose=2)
print('loss, acc :', loss, acc)

# Per-epoch training history (separate names so the scalar test loss above
# is not shadowed by a list).
vloss_hist = history.history['val_loss']
loss_hist = history.history['loss']
print('vloss :', vloss_hist)
print('loss :', loss_hist)

vaccuracy_hist = history.history['val_accuracy']
accuracy_hist = history.history['accuracy']
print('vaccuracy :', vaccuracy_hist)
print('accuracy :', accuracy_hist)

# Visualisation
epoch_len = np.arange(len(accuracy_hist))

plt.plot(epoch_len, vloss_hist, c='red', label='val_loss')
plt.plot(epoch_len, loss_hist, c='blue', label='loss')
plt.xlabel('epochs')
plt.ylabel('loss')
plt.legend(loc='best')
plt.show()

plt.plot(epoch_len, vaccuracy_hist, c='red', label='vaccuracy')
plt.plot(epoch_len, accuracy_hist, c='blue', label='accuracy')
plt.xlabel('epochs')
plt.ylabel('accuracy')
plt.legend(loc='best')
plt.show()

print()
# Load the best saved model and classify new samples.
# x_test has already been standardised with the training-set scaler, which is
# the input scale the model was trained on.
from keras.models import load_model
mymodel = load_model('cl3model.hdf5')
new_data = x_test[:5, :]
print(new_data)
pred = mymodel.predict(new_data)
# The sigmoid output lies in (0, 1); threshold at 0.5 to get the 0/1 class.
print('pred :', np.where(pred > 0.5, 1, 0).flatten())


<console>
    0     1    2    3      4     5     6       7     8     9    10  11  12
0  7.4  0.70  0.0  1.9  0.076  11.0  34.0  0.9978  3.51  0.56  9.4   5   1
1  7.8  0.88  0.0  2.6  0.098  25.0  67.0  0.9968  3.20  0.68  9.8   5   1
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 6497 entries, 0 to 6496
Data columns (total 13 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   0       6497 non-null   float64
 1   1       6497 non-null   float64
 2   2       6497 non-null   float64
 3   3       6497 non-null   float64
 4   4       6497 non-null   float64
 5   5       6497 non-null   float64
 6   6       6497 non-null   float64
 7   7       6497 non-null   float64
 8   8       6497 non-null   float64
 9   9       6497 non-null   float64
 10  10      6497 non-null   float64
 11  11      6497 non-null   int64  
 12  12      6497 non-null   int64  
dtypes: float64(11), int64(2)
memory usage: 660.0 KB
None
[1 0]
4898
1599
[[ 7.4     0.7     0.      1.9     0.076  11.     34.      0.9978  3.51
   0.56    9.4     5.    ]]
[1.]

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense (Dense)               (None, 32)                416       
                                                                 
 dense_1 (Dense)             (None, 16)                528       
                                                                 
 dense_2 (Dense)             (None, 8)                 136       
                                                                 
 dense_3 (Dense)             (None, 1)                 9         
                                                                 
=================================================================
Total params: 1,089
Trainable params: 1,089
Non-trainable params: 0

loss, acc : 0.0589471161365509 0.9856410026550293
vloss : [0.20949411392211914, 0.18525467813014984, 0.17078757286071777, 0.16653349995613098, 0.14237532019615173, 0.12865032255649567, 0.14395548403263092, 0.11100532114505768, 0.12268267571926117, 0.1050923764705658, 0.09925971180200577, 0.09010053426027298, 0.08741936832666397, 0.11127576977014542, 0.07937368005514145, 0.07300767302513123, 0.10087704658508301, 0.07097449898719788, 0.06528158485889435, 0.0651187151670456, 0.07779967784881592, 0.06177672743797302, 0.05944723263382912, 0.0974210798740387, 0.07251864671707153, 0.06752219051122665, 0.0567992702126503, 0.07402747124433517, 0.05377015843987465, 0.052300646901130676, 0.059112224727869034, 0.051389697939157486, 0.08173388987779617, 0.07011207938194275, 0.05566557124257088, 0.0589790940284729, 0.049351297318935394, 0.051809050142765045, 0.05481600761413574, 0.05469350889325142, 0.06485271453857422, 0.05285915732383728]
loss : [0.27672719955444336, 0.21581678092479706, 0.1982879340648651, 0.1879514753818512, 0.17713402211666107, 0.16040614247322083, 0.153812974691391, 0.14045995473861694, 0.13141627609729767, 0.11534695327281952, 0.11378927528858185, 0.10232722759246826, 0.11021049320697784, 0.10000675916671753, 0.08864924311637878, 0.08983210474252701, 0.08327782899141312, 0.08523685485124588, 0.0762864425778389, 0.07394008338451385, 0.08104557543992996, 0.07244405150413513, 0.0693037137389183, 0.06893828511238098, 0.07259487360715866, 0.0741797462105751, 0.06516513973474503, 0.06913289427757263, 0.06532575190067291, 0.06511712074279785, 0.06924497336149216, 0.0667584240436554, 0.06360393017530441, 0.0716901645064354, 0.07001704722642899, 0.06621268391609192, 0.05863720923662186, 0.06977786868810654, 0.062266986817121506, 0.054250918328762054, 0.05558239668607712, 0.06645198911428452]
vaccuracy : [0.9219779968261719, 0.9384615421295166, 0.9351648092269897, 0.9406593441963196, 0.9516483545303345, 0.9549450278282166, 0.9439560174942017, 0.9560439586639404, 0.9703296422958374, 0.9549450278282166, 0.9560439586639404, 0.9725274443626404, 0.9626373648643494, 0.9714285731315613, 0.9703296422958374, 0.9758241772651672, 0.9615384340286255, 0.9802197813987732, 0.9780219793319702, 0.9769230484962463, 0.9791208505630493, 0.9780219793319702, 0.9813186526298523, 0.9659340381622314, 0.9725274443626404, 0.9747252464294434, 0.9813186526298523, 0.9714285731315613, 0.9802197813987732, 0.9813186526298523, 0.9824175834655762, 0.9835164546966553, 0.9714285731315613, 0.9758241772651672, 0.9835164546966553, 0.9835164546966553, 0.9835164546966553, 0.9846153855323792, 0.9813186526298523, 0.9824175834655762, 0.9791208505630493, 0.9802197813987732]
accuracy : [0.8897442817687988, 0.9246631860733032, 0.9274126887321472, 0.9320868849754333, 0.9373109936714172, 0.9408853650093079, 0.9466593265533447, 0.9518834352493286, 0.9521583914756775, 0.9593071341514587, 0.9620566368103027, 0.964256227016449, 0.9620566368103027, 0.9659059643745422, 0.971405029296875, 0.9727797508239746, 0.9733296632766724, 0.9722298383712769, 0.9752543568611145, 0.9777289032936096, 0.9725047945976257, 0.9771789908409119, 0.9788287281990051, 0.979103684425354, 0.9774539470672607, 0.9749794006347656, 0.979103684425354, 0.9780038595199585, 0.981853187084198, 0.979653537273407, 0.9785537719726562, 0.9810283184051514, 0.979103684425354, 0.9782788157463074, 0.9788287281990051, 0.9815782308578491, 0.9821281433105469, 0.9785537719726562, 0.9835028648376465, 0.9835028648376465, 0.984602689743042, 0.9780038595199585]

pred : [0 0 0 0 1]

이항분류의 경우에는 compile 시 loss='binary_crossentropy', metrics=['accuracy']를 주고, 출력층(units=1)에 activation='sigmoid'를 준다. sigmoid 출력은 0~1 사이의 값이므로, 예측 시 0.5를 기준으로 0 또는 1로 분류한다.

 

 

 

loss 시각화

 

accuracy 시각화