  • TensorFlow Basics 43 - Text classification of the Keras Reuters news dataset with an LSTM
    TensorFlow 2022. 12. 19. 13:20

     

    # Text classification of the Reuters news dataset that ships with Keras, using an LSTM.
    # The Reuters corpus contains 11,228 newswire articles (8,982 train + 2,246 test) labeled across 46 topic categories.
    
    from keras.datasets import reuters
    from keras.models import Sequential
    from keras.layers import Dense, Embedding, Flatten
    from keras.utils import to_categorical, pad_sequences
    import matplotlib.pyplot as plt
    
    max_features = 10000
    
    (x_train, y_train), (x_test, y_test) = reuters.load_data(num_words=max_features)
    print(x_train.shape, y_train.shape, x_test.shape, y_test.shape) # (8982,) (8982,) (2246,) (2246,)
    print(len(set(y_train)))  # 46 distinct topic labels
    print(x_train[100])       # one article, encoded as a list of word indices
    print(y_train[100])       # its topic label (an integer in 0..45)
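    
    # Sketch (not in the original post): each integer is a word's frequency rank.
    # Keras ships the vocabulary, so an encoded article can be decoded back to
    # text. Indices 0-2 are reserved (padding/start/unknown), hence the +3 offset.
    word_index = reuters.get_word_index()
    index_to_word = {idx + 3: word for word, idx in word_index.items()}
    print(' '.join(index_to_word.get(i, '?') for i in x_train[100])[:200])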
    
    # train / validation split: first 7,000 articles for training, the remaining 1,982 for validation
    x_val = x_train[7000:]
    y_val = y_train[7000:]
    x_train = x_train[:7000]
    y_train = y_train[:7000]
    
    print(x_train.shape, y_train.shape, x_val.shape, y_val.shape)
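    
    # Equivalent sketch: passing validation_split=0.22 to model.fit() would carve
    # off the last ~22% of the training data automatically; the manual slice
    # above keeps the split explicit.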
    
    # Pad / truncate every article to the same length
    text_max_words = 120
    x_train = pad_sequences(x_train, maxlen=text_max_words)
    x_val = pad_sequences(x_val, maxlen=text_max_words)
    x_test = pad_sequences(x_test, maxlen=text_max_words)
    print(x_train[0], len(x_train[0]))
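    
    # pad_sequences defaults assumed above: padding='pre' and truncating='pre',
    # which is why the zeros sit at the front of the printout below. A 'post'
    # variant would keep the start of each article instead:
    # x_train = pad_sequences(x_train, maxlen=text_max_words,
    #                         padding='post', truncating='post')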
    
    # One-hot encode the 46 topic labels
    y_train = to_categorical(y_train)
    y_val = to_categorical(y_val)
    y_test = to_categorical(y_test)
    print(y_train[0])
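    
    # Alternative sketch (not used here): skip the one-hot step, keep the integer
    # labels, and compile the model with loss='sparse_categorical_crossentropy'.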
    
    # Model construction method 1: Dense layers only (a baseline before the LSTM version)
    model = Sequential()
    model.add(Embedding(max_features, 128, input_length=text_max_words))
    model.add(Flatten())
    model.add(Dense(256, activation='relu'))
    model.add(Dense(46, activation='softmax'))
    model.summary()  # Trainable params: 5,224,238 (summary() prints itself and returns None)
    
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    
    hist = model.fit(x_train, y_train, epochs=10, batch_size=64, validation_data=(x_val, y_val), verbose=2)
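    
    # The log below shows val_loss bottoming out around epoch 2 while training
    # accuracy keeps climbing, i.e. overfitting. A hedged sketch of one common
    # mitigation (not part of the original run):
    # from keras.callbacks import EarlyStopping
    # es = EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)
    # hist = model.fit(x_train, y_train, epochs=30, batch_size=64,
    #                  validation_data=(x_val, y_val), callbacks=[es], verbose=2)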
    
    # Visualization: loss on the left axis, accuracy on the right (twinned) axis, then test evaluation
    def plt_func():
        fig, loss_ax = plt.subplots()
        acc_ax = loss_ax.twinx()
        loss_ax.plot(hist.history['loss'], c='y', label='train loss')
        loss_ax.plot(hist.history['val_loss'], c='r', label='val loss')
        loss_ax.set_ylim([0,3])
        
        acc_ax.plot(hist.history['accuracy'], c='y', label='train accuracy')
        acc_ax.plot(hist.history['val_accuracy'], c='r', label='val accuracy')
        acc_ax.set_ylim([0,1])
    
        loss_ax.legend(loc='upper left')
        acc_ax.legend(loc='lower left')
    
        plt.show()
        
        loss_metrics = model.evaluate(x_test, y_test, batch_size=64)
        print('loss_metrics :', loss_metrics)
    
    plt_func()
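    
    # Sketch (not in the original post): classify a single test article and
    # compare the predicted topic index with the true label.
    import numpy as np
    pred = model.predict(x_test[:1])
    print('predicted topic:', np.argmax(pred, axis=1)[0],
          '/ actual topic:', np.argmax(y_test[0]))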
    
    
    <console>
    (8982,) (8982,) (2246,) (2246,)
    46
    [1, 367, 1394, 169, 65, 87, 209, 30, 306, 228, 10, 803, 305, 96, 5, 196, 15, 10, 523, 2, 3006, 293, 484, 2, 1440, 5825, 8, 145, 7, 10, 1670, 6, 10, 294, 517, 237, 2, 367, 8042, 7, 2477, 1177, 483, 1440, 5825, 8, 367, 1394, 4, 169, 387, 66, 209, 30, 2344, 652, 1496, 9, 209, 30, 2564, 228, 10, 803, 305, 96, 5, 196, 15, 51, 36, 1457, 24, 1345, 5, 4, 196, 150, 10, 523, 320, 64, 992, 6373, 13, 367, 190, 297, 64, 85, 1692, 6, 8656, 122, 9, 36, 1457, 24, 269, 4753, 27, 367, 212, 114, 45, 30, 3292, 7, 126, 2203, 13, 367, 6, 1818, 4, 169, 65, 96, 28, 432, 23, 189, 1254, 4, 9725, 320, 5, 196, 15, 10, 523, 25, 730, 190, 57, 64, 6, 9953, 2016, 6373, 7, 2, 122, 1440, 5825, 8, 269, 4753, 1217, 7, 608, 2203, 30, 3292, 1440, 5825, 8, 43, 339, 43, 231, 9, 667, 1820, 126, 212, 4197, 21, 1709, 249, 311, 13, 260, 489, 9, 65, 4753, 64, 1209, 4397, 249, 954, 36, 152, 1440, 5825, 506, 24, 135, 367, 311, 34, 420, 4, 8407, 200, 1519, 13, 137, 730, 190, 7, 104, 570, 52, 64, 2492, 7725, 4, 642, 5, 405, 7725, 2492, 24, 76, 847, 1435, 4446, 6, 10, 548, 320, 34, 325, 136, 694, 1440, 5825, 8, 10, 5184, 847, 7, 4, 169, 76, 2378, 10, 4933, 3447, 5, 141, 1082, 36, 152, 36, 8, 126, 358, 367, 65, 814, 190, 64, 2575, 10, 969, 3161, 92, 48, 6, 2245, 31, 367, 51, 570, 4753, 292, 27, 405, 212, 62, 3740, 922, 9, 2464, 27, 367, 77, 62, 4397, 7, 316, 5, 874, 36, 152, 4, 936, 1243, 5, 358, 367, 398, 57, 45, 3680, 7367, 6, 2394, 1343, 13, 373, 4504, 36, 8, 1440, 5825, 8, 42, 196, 150, 10, 523, 96, 34, 9725, 43, 16, 1261, 205, 7, 4, 65, 182, 1351, 367, 6, 351, 184, 45, 6081, 2286, 197, 1245, 13, 3187, 2, 274, 419, 714, 1351, 367, 269, 10, 96, 41, 129, 1104, 1673, 1419, 578, 36, 152, 2, 1440, 7615, 367, 1683, 484, 293, 75, 6557, 4, 8042, 152, 24, 5222, 34, 325, 834, 6, 1356, 2, 2406, 7, 4, 65, 76, 1082, 164, 1574, 212, 9, 861, 34, 8340, 13, 286, 1930, 1440, 7615, 8, 787, 36, 1830, 1082, 41, 3751, 616, 6, 382, 2, 2, 1574, 6928, 17, 12]
    20
    (7000,) (7000,) (1982,) (1982,)
    [   0    0    0    0    0    0    0    0    0    0    0    0    0    0
        0    0    0    0    0    0    0    0    0    0    0    0    0    0
        0    0    0    0    0    1    2    2    8   43   10  447    5   25
      207  270    5 3095  111   16  369  186   90   67    7   89    5   19
      102    6   19  124   15   90   67   84   22  482   26    7   48    4
       49    8  864   39  209  154    6  151    6   83   11   15   22  155
       11   15    7   48    9 4579 1005  504    6  258    6  272   11   15
       22  134   44   11   15   16    8  197 1245   90   67   52   29  209
       30   32  132    6  109   15   17   12] 120
    [0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
     0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
    
    Model: "sequential"
    _________________________________________________________________
     Layer (type)                Output Shape              Param #   
    =================================================================
     embedding (Embedding)       (None, 120, 128)          1280000   
                                                                     
     flatten (Flatten)           (None, 15360)             0         
                                                                     
     dense (Dense)               (None, 256)               3932416   
                                                                     
     dense_1 (Dense)             (None, 46)                11822     
                                                                     
    =================================================================
    Total params: 5,224,238
    Trainable params: 5,224,238
    Non-trainable params: 0
    _________________________________________________________________
    Epoch 1/10
    110/110 - 3s - loss: 1.9523 - accuracy: 0.5139 - val_loss: 1.4971 - val_accuracy: 0.6372 - 3s/epoch - 31ms/step
    Epoch 2/10
    110/110 - 3s - loss: 0.9231 - accuracy: 0.7911 - val_loss: 1.2598 - val_accuracy: 0.7094 - 3s/epoch - 27ms/step
    Epoch 3/10
    110/110 - 3s - loss: 0.3035 - accuracy: 0.9439 - val_loss: 1.3157 - val_accuracy: 0.7013 - 3s/epoch - 27ms/step
    Epoch 4/10
    110/110 - 3s - loss: 0.1877 - accuracy: 0.9600 - val_loss: 1.3787 - val_accuracy: 0.6958 - 3s/epoch - 27ms/step
    Epoch 5/10
    110/110 - 3s - loss: 0.1563 - accuracy: 0.9616 - val_loss: 1.3500 - val_accuracy: 0.7053 - 3s/epoch - 28ms/step
    Epoch 6/10
    110/110 - 3s - loss: 0.1321 - accuracy: 0.9630 - val_loss: 1.4119 - val_accuracy: 0.6902 - 3s/epoch - 27ms/step
    Epoch 7/10
    110/110 - 3s - loss: 0.1281 - accuracy: 0.9620 - val_loss: 1.4077 - val_accuracy: 0.6932 - 3s/epoch - 27ms/step
    Epoch 8/10
    110/110 - 3s - loss: 0.1126 - accuracy: 0.9649 - val_loss: 1.3556 - val_accuracy: 0.7018 - 3s/epoch - 27ms/step
    Epoch 9/10
    110/110 - 3s - loss: 0.1085 - accuracy: 0.9633 - val_loss: 1.4004 - val_accuracy: 0.6932 - 3s/epoch - 27ms/step
    Epoch 10/10
    110/110 - 3s - loss: 0.0953 - accuracy: 0.9644 - val_loss: 1.4192 - val_accuracy: 0.6993 - 3s/epoch - 27ms/step
    
     1/36 [..............................] - ETA: 0s - loss: 1.7320 - accuracy: 0.6406
    17/36 [=============>................] - ETA: 0s - loss: 1.4034 - accuracy: 0.6939
    32/36 [=========================>....] - ETA: 0s - loss: 1.4222 - accuracy: 0.6870
    36/36 [==============================] - 0s 3ms/step - loss: 1.4476 - accuracy: 0.6861
    loss_metrics : [1.4476467370986938, 0.6861086487770081]
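
    The model above is the Dense-only baseline; the LSTM version the title refers to follows the same data pipeline, swapping Flatten plus the hidden Dense layer for an LSTM layer over the embedded sequence. A minimal sketch (the layer sizes are illustrative assumptions, not taken from the original run):

    # Model construction method 2: LSTM (sketch; units chosen for illustration)
    from keras.layers import LSTM

    model2 = Sequential()
    model2.add(Embedding(max_features, 128, input_length=text_max_words))
    model2.add(LSTM(128))
    model2.add(Dense(46, activation='softmax'))
    model2.compile(optimizer='adam', loss='categorical_crossentropy',
                   metrics=['accuracy'])
    hist2 = model2.fit(x_train, y_train, epochs=10, batch_size=64,
                       validation_data=(x_val, y_val), verbose=2)
    print('loss_metrics :', model2.evaluate(x_test, y_test, batch_size=64))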