TensorFlow Basics 20 - classification (binary classification on the wine dataset), early stopping, and saving monitoring results to a file during training
TensorFlow | 2022. 12. 2. 17:00
# Build a classification model on the red & white wine dataset
from keras.models import Sequential
from keras.layers import Dense
from keras.callbacks import EarlyStopping, ModelCheckpoint
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

wdf = pd.read_csv('../testdata/wine.csv', header=None)
print(wdf.head(2))
print(wdf.info())
print(wdf.iloc[:, 12].unique())        # [1 0]
print(len(wdf[wdf.iloc[:, 12] == 0]))  # 4898
print(len(wdf[wdf.iloc[:, 12] == 1]))  # 1599

dataset = wdf.values
x = dataset[:, 0:12]   # first 12 columns are the features
y = dataset[:, -1]     # last column is the 0/1 label
print(x[:1])
print(y[:1])

x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.3, random_state=12)

# model
model = Sequential()
model.add(Dense(units=32, input_dim=12, activation='relu'))
model.add(Dense(units=16, activation='relu'))
model.add(Dense(units=8, activation='relu'))
model.add(Dense(units=1, activation='sigmoid'))
print(model.summary())

model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Evaluate without training
loss, acc = model.evaluate(x_train, y_train, verbose=2)
print('loss, acc :', loss, acc)

# Early stopping
early_stop = EarlyStopping(monitor='val_loss', patience=5)

# Save monitoring results (the best model) to a file during training
# (if there is no validation split, monitor 'loss' instead)
chkpoint = ModelCheckpoint(filepath='cl3model.hdf5', monitor='val_loss', verbose=0, save_best_only=True)

history = model.fit(x_train, y_train, epochs=1000, batch_size=32, verbose=2,
                    validation_split=0.2, callbacks=[early_stop, chkpoint])

# Evaluate after training
loss, acc = model.evaluate(x_test, y_test, verbose=2)
print('loss, acc :', loss, acc)

vloss = history.history['val_loss']
loss = history.history['loss']
print('vloss :', vloss)
print('loss :', loss)

vaccuracy = history.history['val_accuracy']
accuracy = history.history['accuracy']
print('vaccuracy :', vaccuracy)
print('accuracy :', accuracy)

# Visualization
epoch_len = np.arange(len(accuracy))
plt.plot(epoch_len, vloss, c='red', label='val_loss')
plt.plot(epoch_len, loss, c='blue', label='loss')
plt.xlabel('epochs')
plt.ylabel('loss')
plt.legend(loc='best')
plt.show()

plt.plot(epoch_len, vaccuracy, c='red', label='vaccuracy')
plt.plot(epoch_len, accuracy, c='blue', label='accuracy')
plt.xlabel('epochs')
plt.ylabel('accuracy')
plt.legend(loc='best')
plt.show()

print()

# Load the best model and classify new data
from keras.models import load_model
mymodel = load_model('cl3model.hdf5')
new_data = x_test[:5, :]
print(new_data)
pred = mymodel.predict(new_data)
print('pred :', np.where(pred > 0.5, 1, 0).flatten())

<console>
     0     1    2    3      4     5     6       7     8     9   10  11  12
0  7.4  0.70  0.0  1.9  0.076  11.0  34.0  0.9978  3.51  0.56  9.4   5   1
1  7.8  0.88  0.0  2.6  0.098  25.0  67.0  0.9968  3.20  0.68  9.8   5   1
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 6497 entries, 0 to 6496
Data columns (total 13 columns):
 #   Column  Non-Null Count  Dtype
---  ------  --------------  -----
 0   0       6497 non-null   float64
 1   1       6497 non-null   float64
 2   2       6497 non-null   float64
 3   3       6497 non-null   float64
 4   4       6497 non-null   float64
 5   5       6497 non-null   float64
 6   6       6497 non-null   float64
 7   7       6497 non-null   float64
 8   8       6497 non-null   float64
 9   9       6497 non-null   float64
 10  10      6497 non-null   float64
 11  11      6497 non-null   int64
 12  12      6497 non-null   int64
dtypes: float64(11), int64(2)
memory usage: 660.0 KB
None
[1 0]
4898
1599
[[ 7.4  0.7  0.  1.9  0.076 11.  34.  0.9978  3.51  0.56  9.4  5. ]]
[1.]
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 dense (Dense)               (None, 32)                416
 dense_1 (Dense)             (None, 16)                528
 dense_2 (Dense)             (None, 8)                 136
 dense_3 (Dense)             (None, 1)                 9
=================================================================
Total params: 1,089
Trainable params: 1,089
Non-trainable params: 0
_________________________________________________________________
loss, acc : 0.0589471161365509 0.9856410026550293
vloss : [0.20949411392211914, 0.18525467813014984, 0.17078757286071777, 0.16653349995613098, 0.14237532019615173, 0.12865032255649567, 0.14395548403263092, 0.11100532114505768, 0.12268267571926117, 0.1050923764705658, 0.09925971180200577, 0.09010053426027298, 0.08741936832666397, 0.11127576977014542, 0.07937368005514145, 0.07300767302513123, 0.10087704658508301, 0.07097449898719788, 0.06528158485889435, 0.0651187151670456, 0.07779967784881592, 0.06177672743797302, 0.05944723263382912, 0.0974210798740387, 0.07251864671707153, 0.06752219051122665, 0.0567992702126503, 0.07402747124433517, 0.05377015843987465, 0.052300646901130676, 0.059112224727869034, 0.051389697939157486, 0.08173388987779617, 0.07011207938194275, 0.05566557124257088, 0.0589790940284729, 0.049351297318935394, 0.051809050142765045, 0.05481600761413574, 0.05469350889325142, 0.06485271453857422, 0.05285915732383728]
loss : [0.27672719955444336, 0.21581678092479706, 0.1982879340648651, 0.1879514753818512, 0.17713402211666107, 0.16040614247322083, 0.153812974691391, 0.14045995473861694, 0.13141627609729767, 0.11534695327281952, 0.11378927528858185, 0.10232722759246826, 0.11021049320697784, 0.10000675916671753, 0.08864924311637878, 0.08983210474252701, 0.08327782899141312, 0.08523685485124588, 0.0762864425778389, 0.07394008338451385, 0.08104557543992996, 0.07244405150413513, 0.0693037137389183, 0.06893828511238098, 0.07259487360715866, 0.0741797462105751, 0.06516513973474503, 0.06913289427757263, 0.06532575190067291, 0.06511712074279785, 0.06924497336149216, 0.0667584240436554, 0.06360393017530441, 0.0716901645064354, 0.07001704722642899, 0.06621268391609192, 0.05863720923662186, 0.06977786868810654, 0.062266986817121506, 0.054250918328762054, 0.05558239668607712, 0.06645198911428452]
vaccuracy : [0.9219779968261719, 0.9384615421295166, 0.9351648092269897, 0.9406593441963196, 0.9516483545303345, 0.9549450278282166, 0.9439560174942017, 0.9560439586639404, 0.9703296422958374, 0.9549450278282166, 0.9560439586639404, 0.9725274443626404, 0.9626373648643494, 0.9714285731315613, 0.9703296422958374, 0.9758241772651672, 0.9615384340286255, 0.9802197813987732, 0.9780219793319702, 0.9769230484962463, 0.9791208505630493, 0.9780219793319702, 0.9813186526298523, 0.9659340381622314, 0.9725274443626404, 0.9747252464294434, 0.9813186526298523, 0.9714285731315613, 0.9802197813987732, 0.9813186526298523, 0.9824175834655762, 0.9835164546966553, 0.9714285731315613, 0.9758241772651672, 0.9835164546966553, 0.9835164546966553, 0.9835164546966553, 0.9846153855323792, 0.9813186526298523, 0.9824175834655762, 0.9791208505630493, 0.9802197813987732]
accuracy : [0.8897442817687988, 0.9246631860733032, 0.9274126887321472, 0.9320868849754333, 0.9373109936714172, 0.9408853650093079, 0.9466593265533447, 0.9518834352493286, 0.9521583914756775, 0.9593071341514587, 0.9620566368103027, 0.964256227016449, 0.9620566368103027, 0.9659059643745422, 0.971405029296875, 0.9727797508239746, 0.9733296632766724, 0.9722298383712769, 0.9752543568611145, 0.9777289032936096, 0.9725047945976257, 0.9771789908409119, 0.9788287281990051, 0.979103684425354, 0.9774539470672607, 0.9749794006347656, 0.979103684425354, 0.9780038595199585, 0.981853187084198, 0.979653537273407, 0.9785537719726562, 0.9810283184051514, 0.979103684425354, 0.9782788157463074, 0.9788287281990051, 0.9815782308578491, 0.9821281433105469, 0.9785537719726562, 0.9835028648376465, 0.9835028648376465, 0.984602689743042, 0.9780038595199585]
pred : [0 0 0 0 1]
</console>
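A side note that is not in the original post: EarlyStopping also accepts restore_best_weights=True, which rolls the model back to the weights of its best validation epoch when training stops, so the checkpoint file does not have to be reloaded just to recover the best model. A minimal sketch, assuming the same model, x_train, and y_train as above:

from keras.callbacks import EarlyStopping

# Variation (assumption, not the code used above): keep the best-epoch weights
# in memory instead of only on disk.
early_stop = EarlyStopping(monitor='val_loss', patience=5,
                           restore_best_weights=True)
model.fit(x_train, y_train, epochs=1000, batch_size=32, verbose=2,
          validation_split=0.2, callbacks=[early_stop])
# After fit() returns, the model already holds the best val_loss weights.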
For binary classification, use loss='binary_crossentropy' and metrics=['accuracy'] in compile(), and activation='sigmoid' on the output layer.
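As a minimal sketch of that pairing (a toy model for illustration, not the one trained above): a single sigmoid unit outputs the probability of class 1, and binary_crossentropy scores it against the 0/1 labels.

from keras.models import Sequential
from keras.layers import Dense

m = Sequential()
m.add(Dense(16, input_dim=12, activation='relu'))
m.add(Dense(1, activation='sigmoid'))  # outputs P(class = 1), a value in (0, 1)
m.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
# Predictions come out as probabilities, so threshold at 0.5 to get 0/1 labels,
# e.g. np.where(m.predict(new_data) > 0.5, 1, 0), as done above.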
[Figure: loss visualization - val_loss (red) and loss (blue) per epoch]
[Figure: accuracy visualization - vaccuracy (red) and accuracy (blue) per epoch]