TensorFlow
TensorFlow 기초 20 - classification(이항분류 wine dataset), 학습 조기 종료, 모델 학습 시 모니터링 결과를 파일로 저장
코딩탕탕
2022. 12. 2. 17:00
# red&white wine dataset으로 분류 모델 작성
from keras.models import Sequential
from keras.layers import Dense
from keras.callbacks import EarlyStopping, ModelCheckpoint
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
# Load the wine table. The CSV has no header row, so pandas numbers the
# columns 0..12.
wdf = pd.read_csv('../testdata/wine.csv', header=None)
print(wdf.head(2))
print(wdf.info())

# Column 12 is used as the binary class label; show its distinct values and
# how many rows belong to each class.
label_col = wdf[12]
print(label_col.unique())  # [1 0]
print(len(wdf[label_col == 0]))  # 4898
print(len(wdf[label_col == 1]))  # 1599

# Work on the raw array: the first 12 columns are the features, the last
# column is the target.
dataset = wdf.to_numpy()
x, y = dataset[:, :12], dataset[:, -1]
print(x[:1])
print(y[:1])

# Hold out 30% of the rows for testing; the fixed seed keeps the split
# reproducible.
x_train, x_test, y_train, y_test = train_test_split(
    x, y, test_size=0.3, random_state=12
)
# Model: a fully connected network that narrows 32 -> 16 -> 8 units and ends
# in a single sigmoid unit, i.e. one probability per sample for the binary
# target.
model = Sequential([
    Dense(units=32, input_dim=12, activation='relu'),
    Dense(units=16, activation='relu'),
    Dense(units=8, activation='relu'),
    Dense(units=1, activation='sigmoid'),
])
print(model.summary())

# Binary cross-entropy pairs with the sigmoid output; accuracy is tracked as
# an additional metric.
model.compile(
    loss='binary_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)
# Baseline: evaluate the model before any training has taken place.
loss, acc = model.evaluate(x_train, y_train, verbose=2)
print('loss, acc :', loss, acc)

# Early stopping: watch the validation loss and stop training after 5 epochs
# without improvement.
early_stop = EarlyStopping(monitor='val_loss', patience=5)

# Checkpointing: write the model to disk during training, keeping only the
# best one according to the monitored value. (When no validation data is
# used, monitor 'loss' instead of 'val_loss'.)
# NOTE(review): newer Keras releases may reject the '.hdf5' suffix - confirm
# against the installed version.
chkpoint = ModelCheckpoint(
    filepath='cl3model.hdf5',
    monitor='val_loss',
    verbose=0,
    save_best_only=True,
)

# 20% of the training rows are set aside as validation data, which is what
# produces the 'val_loss' value both callbacks monitor.
fit_options = dict(
    epochs=1000,
    batch_size=32,
    verbose=2,
    validation_split=0.2,
    callbacks=[early_stop, chkpoint],
)
history = model.fit(x_train, y_train, **fit_options)

# Evaluate the trained model on the held-out test set.
loss, acc = model.evaluate(x_test, y_test, verbose=2)
print('loss, acc :', loss, acc)
# Per-epoch metrics recorded by fit(): validation and training loss/accuracy.
epoch_logs = history.history
vloss, loss = epoch_logs['val_loss'], epoch_logs['loss']
print('vloss :', vloss)
print('loss :', loss)
vaccuracy, accuracy = epoch_logs['val_accuracy'], epoch_logs['accuracy']
print('vaccuracy :', vaccuracy)
print('accuracy :', accuracy)

# Visualization: one figure per metric, validation curve in red and training
# curve in blue, plotted against the epoch index.
epoch_len = np.arange(len(accuracy))
for axis_label, val_curve, val_name, train_curve, train_name in (
    ('loss', vloss, 'val_loss', loss, 'loss'),
    ('accuracy', vaccuracy, 'vaccuracy', accuracy, 'accuracy'),
):
    plt.plot(epoch_len, val_curve, c='red', label=val_name)
    plt.plot(epoch_len, train_curve, c='blue', label=train_name)
    plt.xlabel('epochs')
    plt.ylabel(axis_label)
    plt.legend(loc='best')
    plt.show()
print()

# Reload the best checkpointed model and classify new data with it.
from keras.models import load_model

mymodel = load_model('cl3model.hdf5')

# The first five test rows stand in for "new" samples.
new_data = x_test[:5]
print(new_data)

# The sigmoid output is a probability; values above 0.5 map to class 1,
# everything else to class 0.
pred = mymodel.predict(new_data)
pred_labels = np.where(pred > 0.5, 1, 0)
print('pred :', pred_labels.ravel())
<console>
0 1 2 3 4 5 6 7 8 9 10 11 12
0 7.4 0.70 0.0 1.9 0.076 11.0 34.0 0.9978 3.51 0.56 9.4 5 1
1 7.8 0.88 0.0 2.6 0.098 25.0 67.0 0.9968 3.20 0.68 9.8 5 1
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 6497 entries, 0 to 6496
Data columns (total 13 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 0 6497 non-null float64
1 1 6497 non-null float64
2 2 6497 non-null float64
3 3 6497 non-null float64
4 4 6497 non-null float64
5 5 6497 non-null float64
6 6 6497 non-null float64
7 7 6497 non-null float64
8 8 6497 non-null float64
9 9 6497 non-null float64
10 10 6497 non-null float64
11 11 6497 non-null int64
12 12 6497 non-null int64
dtypes: float64(11), int64(2)
memory usage: 660.0 KB
None
[1 0]
4898
1599
[[ 7.4 0.7 0. 1.9 0.076 11. 34. 0.9978 3.51
0.56 9.4 5. ]]
[1.]
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense (Dense) (None, 32) 416
dense_1 (Dense) (None, 16) 528
dense_2 (Dense) (None, 8) 136
dense_3 (Dense) (None, 1) 9
=================================================================
Total params: 1,089
Trainable params: 1,089
Non-trainable params: 0
loss, acc : 0.0589471161365509 0.9856410026550293
vloss : [0.20949411392211914, 0.18525467813014984, 0.17078757286071777, 0.16653349995613098, 0.14237532019615173, 0.12865032255649567, 0.14395548403263092, 0.11100532114505768, 0.12268267571926117, 0.1050923764705658, 0.09925971180200577, 0.09010053426027298, 0.08741936832666397, 0.11127576977014542, 0.07937368005514145, 0.07300767302513123, 0.10087704658508301, 0.07097449898719788, 0.06528158485889435, 0.0651187151670456, 0.07779967784881592, 0.06177672743797302, 0.05944723263382912, 0.0974210798740387, 0.07251864671707153, 0.06752219051122665, 0.0567992702126503, 0.07402747124433517, 0.05377015843987465, 0.052300646901130676, 0.059112224727869034, 0.051389697939157486, 0.08173388987779617, 0.07011207938194275, 0.05566557124257088, 0.0589790940284729, 0.049351297318935394, 0.051809050142765045, 0.05481600761413574, 0.05469350889325142, 0.06485271453857422, 0.05285915732383728]
loss : [0.27672719955444336, 0.21581678092479706, 0.1982879340648651, 0.1879514753818512, 0.17713402211666107, 0.16040614247322083, 0.153812974691391, 0.14045995473861694, 0.13141627609729767, 0.11534695327281952, 0.11378927528858185, 0.10232722759246826, 0.11021049320697784, 0.10000675916671753, 0.08864924311637878, 0.08983210474252701, 0.08327782899141312, 0.08523685485124588, 0.0762864425778389, 0.07394008338451385, 0.08104557543992996, 0.07244405150413513, 0.0693037137389183, 0.06893828511238098, 0.07259487360715866, 0.0741797462105751, 0.06516513973474503, 0.06913289427757263, 0.06532575190067291, 0.06511712074279785, 0.06924497336149216, 0.0667584240436554, 0.06360393017530441, 0.0716901645064354, 0.07001704722642899, 0.06621268391609192, 0.05863720923662186, 0.06977786868810654, 0.062266986817121506, 0.054250918328762054, 0.05558239668607712, 0.06645198911428452]
vaccuracy : [0.9219779968261719, 0.9384615421295166, 0.9351648092269897, 0.9406593441963196, 0.9516483545303345, 0.9549450278282166, 0.9439560174942017, 0.9560439586639404, 0.9703296422958374, 0.9549450278282166, 0.9560439586639404, 0.9725274443626404, 0.9626373648643494, 0.9714285731315613, 0.9703296422958374, 0.9758241772651672, 0.9615384340286255, 0.9802197813987732, 0.9780219793319702, 0.9769230484962463, 0.9791208505630493, 0.9780219793319702, 0.9813186526298523, 0.9659340381622314, 0.9725274443626404, 0.9747252464294434, 0.9813186526298523, 0.9714285731315613, 0.9802197813987732, 0.9813186526298523, 0.9824175834655762, 0.9835164546966553, 0.9714285731315613, 0.9758241772651672, 0.9835164546966553, 0.9835164546966553, 0.9835164546966553, 0.9846153855323792, 0.9813186526298523, 0.9824175834655762, 0.9791208505630493, 0.9802197813987732]
accuracy : [0.8897442817687988, 0.9246631860733032, 0.9274126887321472, 0.9320868849754333, 0.9373109936714172, 0.9408853650093079, 0.9466593265533447, 0.9518834352493286, 0.9521583914756775, 0.9593071341514587, 0.9620566368103027, 0.964256227016449, 0.9620566368103027, 0.9659059643745422, 0.971405029296875, 0.9727797508239746, 0.9733296632766724, 0.9722298383712769, 0.9752543568611145, 0.9777289032936096, 0.9725047945976257, 0.9771789908409119, 0.9788287281990051, 0.979103684425354, 0.9774539470672607, 0.9749794006347656, 0.979103684425354, 0.9780038595199585, 0.981853187084198, 0.979653537273407, 0.9785537719726562, 0.9810283184051514, 0.979103684425354, 0.9782788157463074, 0.9788287281990051, 0.9815782308578491, 0.9821281433105469, 0.9785537719726562, 0.9835028648376465, 0.9835028648376465, 0.984602689743042, 0.9780038595199585]
pred : [0 0 0 0 1]
이항분류의 경우에는 출력층에 activation='sigmoid'를 주고, 모델을 compile할 때 loss='binary_crossentropy', metrics=['accuracy']를 지정한다.

