# Text classification on the Reuters newswire dataset bundled with Keras.
# The dataset contains 11,228 news articles, each labeled with one of 46 topic categories.
from keras.datasets import reuters
try:
    from keras.utils import np_utils  # older Keras
except ImportError:
    from keras import utils as np_utils  # newer Keras exposes to_categorical on keras.utils directly
from keras.models import Sequential
from keras.layers import Dense, Embedding, Flatten
from keras.utils import pad_sequences
import matplotlib.pyplot as plt
max_features = 10000
(x_train, y_train), (x_test, y_test) = reuters.load_data(num_words=max_features)
print(x_train.shape, y_train.shape, x_test.shape, y_test.shape) # (8982,) (8982,) (2246,) (2246,)
print(len(set(y_train)))
print(x_train[100])
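# Each article is stored as a list of word indices. With reuters.load_data's
# defaults, index 1 marks the start of a sequence, index 2 is out-of-vocabulary,
# and real words are offset by 3, so decoding subtracts 3 from each stored index.
# A minimal decoding sketch (this tiny word_index dict is a made-up stand-in for
# the real reuters.get_word_index() mapping):
word_index = {'the': 1, 'of': 2, 'said': 5}  # hypothetical subset
index_word = {idx + 3: word for word, idx in word_index.items()}
index_word[1], index_word[2] = '<start>', '<oov>'
decoded = ' '.join(index_word.get(i, '?') for i in [1, 4, 8, 999])  # '<start> the said ?'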
print(y_train[100])
# train / validation split
x_val = x_train[7000:]
y_val = y_train[7000:]
x_train = x_train[:7000]
y_train = y_train[:7000]
print(x_train.shape, y_train.shape, x_val.shape, y_val.shape)
# pad/truncate every article to a fixed length
text_max_words = 120
x_train = pad_sequences(x_train, maxlen=text_max_words)
x_val = pad_sequences(x_val, maxlen=text_max_words)
x_test = pad_sequences(x_test, maxlen=text_max_words)
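# pad_sequences defaults to padding='pre' and truncating='pre': short articles get
# zeros prepended and long articles keep only their last maxlen tokens, which is
# why the printed sample starts with a run of zeros. A rough pure-Python sketch of
# that default behavior:
def pad_pre(seq, maxlen):
    seq = seq[-maxlen:]                      # pre-truncate: keep the last maxlen tokens
    return [0] * (maxlen - len(seq)) + seq   # pre-pad with zeros
# pad_pre([5, 6, 7], 5)          -> [0, 0, 5, 6, 7]
# pad_pre([1, 2, 3, 4, 5, 6], 5) -> [2, 3, 4, 5, 6]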
print(x_train[0], len(x_train[0]))
# one-hot encode the 46 topic labels
y_train = np_utils.to_categorical(y_train)
y_val = np_utils.to_categorical(y_val)
y_test = np_utils.to_categorical(y_test)
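# to_categorical turns label k into a length-num_classes vector with a single 1.0
# at position k. A minimal pure-Python equivalent, for illustration only:
def one_hot(label, num_classes=46):
    vec = [0.0] * num_classes
    vec[label] = 1.0
    return vec
# one_hot(3) has a 1.0 at index 3, matching the y_train[0] printout below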
print(y_train[0])
# Model option 1: Dense layers only (Embedding -> Flatten -> Dense)
model = Sequential()
model.add(Embedding(max_features, 128, input_length=text_max_words))
model.add(Flatten())
model.add(Dense(256, activation='relu'))
model.add(Dense(46, activation='softmax'))
model.summary()  # Trainable params: 5,224,238
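# Where the 5,224,238 trainable parameters come from (arithmetic check):
emb_params = 10000 * 128              # Embedding: vocab_size x embed_dim     = 1,280,000
dense_params = 120 * 128 * 256 + 256  # Dense(256): 15360 flat inputs + bias  = 3,932,416
out_params = 256 * 46 + 46            # Dense(46): 256 inputs + bias          = 11,822
total_params = emb_params + dense_params + out_params  # 5,224,238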
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
hist = model.fit(x_train, y_train, epochs=10, batch_size=64, validation_data=(x_val, y_val), verbose=2)
# visualization: loss on the left axis, accuracy on the right axis, per epoch
def plt_func():
    fig, loss_ax = plt.subplots()
    acc_ax = loss_ax.twinx()
    loss_ax.plot(hist.history['loss'], c='y', label='train loss')
    loss_ax.plot(hist.history['val_loss'], c='r', label='val loss')
    loss_ax.set_ylim([0, 3])
    loss_ax.set_xlabel('epoch')
    loss_ax.set_ylabel('loss')
    acc_ax.plot(hist.history['accuracy'], c='b', label='train accuracy')
    acc_ax.plot(hist.history['val_accuracy'], c='g', label='val accuracy')
    acc_ax.set_ylim([0, 1])
    acc_ax.set_ylabel('accuracy')
    loss_ax.legend(loc='upper left')
    acc_ax.legend(loc='lower left')
    plt.show()
loss_metrics = model.evaluate(x_test, y_test, batch_size=64)
print('loss_metrics :', loss_metrics)
plt_func()
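# Model option 2 (sketch): an LSTM variant of the classifier above. The 128 LSTM
# units are an illustrative, untuned choice; input_length is omitted from the
# Embedding layer because newer Keras versions no longer accept that argument.
from keras.layers import LSTM
model2 = Sequential()
model2.add(Embedding(10000, 128))  # 10000 = max_features
model2.add(LSTM(128))              # reads the 120 timesteps sequentially instead of flattening
model2.add(Dense(46, activation='softmax'))
model2.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
# model2.fit(x_train, y_train, epochs=10, batch_size=64, validation_data=(x_val, y_val))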
<console>
(8982,) (8982,) (2246,) (2246,)
46
[1, 367, 1394, 169, 65, 87, 209, 30, 306, 228, 10, 803, 305, 96, 5, 196, 15, 10, 523, 2, 3006, 293, 484, 2, 1440, 5825, 8, 145, 7, 10, 1670, 6, 10, 294, 517, 237, 2, 367, 8042, 7, 2477, 1177, 483, 1440, 5825, 8, 367, 1394, 4, 169, 387, 66, 209, 30, 2344, 652, 1496, 9, 209, 30, 2564, 228, 10, 803, 305, 96, 5, 196, 15, 51, 36, 1457, 24, 1345, 5, 4, 196, 150, 10, 523, 320, 64, 992, 6373, 13, 367, 190, 297, 64, 85, 1692, 6, 8656, 122, 9, 36, 1457, 24, 269, 4753, 27, 367, 212, 114, 45, 30, 3292, 7, 126, 2203, 13, 367, 6, 1818, 4, 169, 65, 96, 28, 432, 23, 189, 1254, 4, 9725, 320, 5, 196, 15, 10, 523, 25, 730, 190, 57, 64, 6, 9953, 2016, 6373, 7, 2, 122, 1440, 5825, 8, 269, 4753, 1217, 7, 608, 2203, 30, 3292, 1440, 5825, 8, 43, 339, 43, 231, 9, 667, 1820, 126, 212, 4197, 21, 1709, 249, 311, 13, 260, 489, 9, 65, 4753, 64, 1209, 4397, 249, 954, 36, 152, 1440, 5825, 506, 24, 135, 367, 311, 34, 420, 4, 8407, 200, 1519, 13, 137, 730, 190, 7, 104, 570, 52, 64, 2492, 7725, 4, 642, 5, 405, 7725, 2492, 24, 76, 847, 1435, 4446, 6, 10, 548, 320, 34, 325, 136, 694, 1440, 5825, 8, 10, 5184, 847, 7, 4, 169, 76, 2378, 10, 4933, 3447, 5, 141, 1082, 36, 152, 36, 8, 126, 358, 367, 65, 814, 190, 64, 2575, 10, 969, 3161, 92, 48, 6, 2245, 31, 367, 51, 570, 4753, 292, 27, 405, 212, 62, 3740, 922, 9, 2464, 27, 367, 77, 62, 4397, 7, 316, 5, 874, 36, 152, 4, 936, 1243, 5, 358, 367, 398, 57, 45, 3680, 7367, 6, 2394, 1343, 13, 373, 4504, 36, 8, 1440, 5825, 8, 42, 196, 150, 10, 523, 96, 34, 9725, 43, 16, 1261, 205, 7, 4, 65, 182, 1351, 367, 6, 351, 184, 45, 6081, 2286, 197, 1245, 13, 3187, 2, 274, 419, 714, 1351, 367, 269, 10, 96, 41, 129, 1104, 1673, 1419, 578, 36, 152, 2, 1440, 7615, 367, 1683, 484, 293, 75, 6557, 4, 8042, 152, 24, 5222, 34, 325, 834, 6, 1356, 2, 2406, 7, 4, 65, 76, 1082, 164, 1574, 212, 9, 861, 34, 8340, 13, 286, 1930, 1440, 7615, 8, 787, 36, 1830, 1082, 41, 3751, 616, 6, 382, 2, 2, 1574, 6928, 17, 12]
20
(7000,) (7000,) (1982,) (1982,)
[ 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 1 2 2 8 43 10 447 5 25
207 270 5 3095 111 16 369 186 90 67 7 89 5 19
102 6 19 124 15 90 67 84 22 482 26 7 48 4
49 8 864 39 209 154 6 151 6 83 11 15 22 155
11 15 7 48 9 4579 1005 504 6 258 6 272 11 15
22 134 44 11 15 16 8 197 1245 90 67 52 29 209
30 32 132 6 109 15 17 12] 120
[0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
embedding (Embedding) (None, 120, 128) 1280000
flatten (Flatten) (None, 15360) 0
dense (Dense) (None, 256) 3932416
dense_1 (Dense) (None, 46) 11822
=================================================================
Total params: 5,224,238
Trainable params: 5,224,238
Non-trainable params: 0
_________________________________________________________________
Epoch 1/10
110/110 - 3s - loss: 1.9523 - accuracy: 0.5139 - val_loss: 1.4971 - val_accuracy: 0.6372 - 3s/epoch - 31ms/step
Epoch 2/10
110/110 - 3s - loss: 0.9231 - accuracy: 0.7911 - val_loss: 1.2598 - val_accuracy: 0.7094 - 3s/epoch - 27ms/step
Epoch 3/10
110/110 - 3s - loss: 0.3035 - accuracy: 0.9439 - val_loss: 1.3157 - val_accuracy: 0.7013 - 3s/epoch - 27ms/step
Epoch 4/10
110/110 - 3s - loss: 0.1877 - accuracy: 0.9600 - val_loss: 1.3787 - val_accuracy: 0.6958 - 3s/epoch - 27ms/step
Epoch 5/10
110/110 - 3s - loss: 0.1563 - accuracy: 0.9616 - val_loss: 1.3500 - val_accuracy: 0.7053 - 3s/epoch - 28ms/step
Epoch 6/10
110/110 - 3s - loss: 0.1321 - accuracy: 0.9630 - val_loss: 1.4119 - val_accuracy: 0.6902 - 3s/epoch - 27ms/step
Epoch 7/10
110/110 - 3s - loss: 0.1281 - accuracy: 0.9620 - val_loss: 1.4077 - val_accuracy: 0.6932 - 3s/epoch - 27ms/step
Epoch 8/10
110/110 - 3s - loss: 0.1126 - accuracy: 0.9649 - val_loss: 1.3556 - val_accuracy: 0.7018 - 3s/epoch - 27ms/step
Epoch 9/10
110/110 - 3s - loss: 0.1085 - accuracy: 0.9633 - val_loss: 1.4004 - val_accuracy: 0.6932 - 3s/epoch - 27ms/step
Epoch 10/10
110/110 - 3s - loss: 0.0953 - accuracy: 0.9644 - val_loss: 1.4192 - val_accuracy: 0.6993 - 3s/epoch - 27ms/step
36/36 [==============================] - 0s 3ms/step - loss: 1.4476 - accuracy: 0.6861
loss_metrics : [1.4476467370986938, 0.6861086487770081]