ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • XGBoost로 분류 모델 예시(kaggle.com이 제공하는 'glass datasets')
    Python 데이터 분석 2022. 11. 23. 14:40

     

    <작성자 코드>

    # [XGBoost exercise]
    # Dataset: the 'glass datasets' provided by kaggle.com.
    # A glass identification database in which samples are separated into
    # 7 kinds of label (Type) according to several features.
    #
    # Columns: RI    Na    Mg    Al    Si    K    Ca    Ba    Fe    Type
    #                           ...
    # Task: read the glass.csv file and perform classification.

    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    import xgboost as xgb
    from sklearn import metrics
    from sklearn.model_selection import train_test_split
    from sklearn.preprocessing import LabelEncoder
    from xgboost import plot_importance

    # Load the data and print a quick overview (head, schema, statistics,
    # pairwise correlations).
    df = pd.read_csv('../testdata/glass.csv')
    print(df.head(3), df.shape)  # (214, 10)
    print(df.info())
    print(df.describe())
    print(df.corr())

    # Separate the target column (Type) from the feature columns.
    df_y = df['Type']
    df_x = df.drop(columns=['Type'])

    # Encode the Type labels as integer codes before training.
    le = LabelEncoder()
    df_y = le.fit_transform(df_y)

    # Hold out 20% of the rows for testing; the seed keeps the split repeatable.
    x_train, x_test, y_train, y_test = train_test_split(
        df_x, df_y, test_size=0.2, random_state=12
    )
    # (171, 9) (43, 9) (171,) (43,)
    print(x_train.shape, x_test.shape, y_train.shape, y_test.shape)

    # Tree-based booster ('gbtree') with 500 boosting rounds, depth up to 6.
    model = xgb.XGBClassifier(booster='gbtree', max_depth=6, n_estimators=500)
    model = model.fit(x_train, y_train)

    # Compare the first ten predictions against the true (encoded) labels.
    pred = model.predict(x_test)
    print('예측값 :', pred[:10])
    print('실제값 :', y_test[:10])

    # Accuracy on the held-out test set.
    print('정확도 확인 방법 1')
    acc = metrics.accuracy_score(y_test, pred)
    print('acc :', acc)

    # Visualise feature importance on a tall figure so all bars are readable.
    fig, ax = plt.subplots(figsize=(10, 12))
    plot_importance(model, ax=ax)
    plt.show()
    
    
    <console>
            RI     Na    Mg    Al     Si     K    Ca   Ba   Fe  Type
    0  1.52101  13.64  4.49  1.10  71.78  0.06  8.75  0.0  0.0     1
    1  1.51761  13.89  3.60  1.36  72.73  0.48  7.83  0.0  0.0     1
    2  1.51618  13.53  3.55  1.54  72.99  0.39  7.78  0.0  0.0     1 (214, 10)
    <class 'pandas.core.frame.DataFrame'>
    RangeIndex: 214 entries, 0 to 213
    Data columns (total 10 columns):
     #   Column  Non-Null Count  Dtype  
    ---  ------  --------------  -----  
     0   RI      214 non-null    float64
     1   Na      214 non-null    float64
     2   Mg      214 non-null    float64
     3   Al      214 non-null    float64
     4   Si      214 non-null    float64
     5   K       214 non-null    float64
     6   Ca      214 non-null    float64
     7   Ba      214 non-null    float64
     8   Fe      214 non-null    float64
     9   Type    214 non-null    int64  
    dtypes: float64(9), int64(1)
    memory usage: 16.8 KB
    None
                   RI          Na          Mg  ...          Ba          Fe        Type
    count  214.000000  214.000000  214.000000  ...  214.000000  214.000000  214.000000
    mean     1.518365   13.407850    2.684533  ...    0.175047    0.057009    2.780374
    std      0.003037    0.816604    1.442408  ...    0.497219    0.097439    2.103739
    min      1.511150   10.730000    0.000000  ...    0.000000    0.000000    1.000000
    25%      1.516522   12.907500    2.115000  ...    0.000000    0.000000    1.000000
    50%      1.517680   13.300000    3.480000  ...    0.000000    0.000000    2.000000
    75%      1.519157   13.825000    3.600000  ...    0.000000    0.100000    3.000000
    max      1.533930   17.380000    4.490000  ...    3.150000    0.510000    7.000000
    
    [8 rows x 10 columns]
                RI        Na        Mg  ...        Ba        Fe      Type
    RI    1.000000 -0.191885 -0.122274  ... -0.000386  0.143010 -0.164237
    Na   -0.191885  1.000000 -0.273732  ...  0.326603 -0.241346  0.502898
    Mg   -0.122274 -0.273732  1.000000  ... -0.492262  0.083060 -0.744993
    Al   -0.407326  0.156794 -0.481799  ...  0.479404 -0.074402  0.598829
    Si   -0.542052 -0.069809 -0.165927  ... -0.102151 -0.094201  0.151565
    K    -0.289833 -0.266087  0.005396  ... -0.042618 -0.007719 -0.010054
    Ca    0.810403 -0.275442 -0.443750  ... -0.112841  0.124968  0.000952
    Ba   -0.000386  0.326603 -0.492262  ...  1.000000 -0.058692  0.575161
    Fe    0.143010 -0.241346  0.083060  ... -0.058692  1.000000 -0.188278
    Type -0.164237  0.502898 -0.744993  ...  0.575161 -0.188278  1.000000
    
    [10 rows x 10 columns]
    (171, 9) (43, 9) (171,) (43,)
    예측값 : [1 1 1 0 5 1 2 5 0 0]
    실제값 : [1 1 4 0 5 4 2 5 0 0]
    정확도 확인 방법 1
    acc : 0.8837209302325582

     

     

     

    <선생님 코드>

    # [XGBoost exercise]  (reference solution - solve it this way)
    # Dataset: the 'glass datasets' provided by kaggle.com.
    # A glass identification database in which samples are separated into
    # 7 kinds of label (Type) according to several features.
    # Columns: RI    Na    Mg    Al    Si    K    Ca    Ba    Fe    Type
    #                           ...
    # Task: read the glass.csv file and perform classification.

    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    import xgboost as xgb
    from sklearn import metrics
    # Import from the public package path; `sklearn.model_selection._split`
    # is a private module and may move between scikit-learn releases.
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import roc_auc_score
    from sklearn.preprocessing import LabelEncoder
    from xgboost import plot_importance

    data = pd.read_csv("../testdata/glass.csv")
    print(data.columns)

    x = data.drop('Type', axis=1)  # exclude the Type column from the features
    y = data['Type']

    print(set(y))  # {1, 2, 3, 5, 6, 7}

    # The raw labels skip the value 4; re-encode them as consecutive integers.
    le = LabelEncoder()
    y = le.fit_transform(y)
    print(y[:3], set(y))  # {0, 1, 2, 3, 4, 5}

    # Hold out 30% of the rows for testing; the seed keeps the split repeatable.
    x_train, x_test, y_train, y_test = train_test_split(
        x, y, test_size=0.3, random_state=12
    )

    model = xgb.XGBClassifier(booster='gbtree', n_estimators=500, random_state=12)
    model.fit(x_train, y_train)

    print()
    y_pred = model.predict(x_test)
    # Fix: the labels were swapped in the original ('actual' printed the
    # predictions and 'predicted' printed the true labels).
    print('실제값 :', np.array(y_test[:5]))
    print('예측값:', y_pred[:5])
    print('정확도 :', metrics.accuracy_score(y_test, y_pred))

    # Multiclass ROC AUC is computed from the per-class probabilities.
    # Without an explicit multi_class argument roc_auc_score raises
    # "ValueError: multi_class must be in ('ovo', 'ovr')", so pass "ovr".
    xgb_roc_auc = roc_auc_score(
        y_test, model.predict_proba(x_test), multi_class="ovr"
    )
    print('ROC AUC : {0:.4f}'.format(xgb_roc_auc))

    # Visualise feature importance.
    plot_importance(model)
    plt.show()
    
    
    <console>
    Index(['RI', 'Na', 'Mg', 'Al', 'Si', 'K', 'Ca', 'Ba', 'Fe', 'Type'], dtype='object')
    {1, 2, 3, 5, 6, 7}
    [0 0 0] {0, 1, 2, 3, 4, 5}
    
    실제값 : [1 1 4 0 5]
    예측값: [1 1 4 0 5]
    정확도 : 0.8
    ROC AUC : 0.9565

     

     

     

    댓글

Designed by Tistory.