  • TensorFlow Basics 13 - Multiple Linear Regression Model (scaling) - Normalization, Standardization, validation_split
    TensorFlow 2022. 11. 30. 17:41

     

     

    Multiple linear regression model

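    scaling: when the features differ greatly in units, normalization or standardization is effective. Scaling is applied to the features only, not to the label.
    Standardization: (value - mean) / standard deviation
    Normalization: (value - min) / (max - min)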
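The two formulas can be checked directly with NumPy. This is a small sketch on a handful of hypothetical tv/radio-like values (not the full dataset); it computes both formulas column by column and compares them with scikit-learn's StandardScaler and MinMaxScaler, which use the population standard deviation (ddof=0), the same as NumPy's default.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler, MinMaxScaler

# Hypothetical toy feature matrix: two columns with very different scales
x = np.array([[230.1, 37.8],
              [44.5, 39.3],
              [17.2, 45.9],
              [151.5, 41.3]])

# Standardization: (value - mean) / standard deviation, per column
standardized = (x - x.mean(axis=0)) / x.std(axis=0)

# Normalization: (value - min) / (max - min), per column
normalized = (x - x.min(axis=0)) / (x.max(axis=0) - x.min(axis=0))

# The hand-written formulas agree with scikit-learn's scalers
print(np.allclose(standardized, StandardScaler().fit_transform(x)))  # True
print(np.allclose(normalized, MinMaxScaler().fit_transform(x)))      # True
```

After standardization each column has mean 0 and standard deviation 1; after normalization each column spans exactly 0 to 1.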

     

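    StandardScaler: standardization. Outliers pull the mean and standard deviation, so the scaled values can become unbalanced.
    MinMaxScaler: normalization. Sensitive to outliers, because a single extreme value sets the min or max.
    RobustScaler: scales with the median and interquartile range, which minimizes the influence of outliers.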
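To see the difference, here is a sketch that feeds one hypothetical feature (four ordinary values plus one outlier) to all three scalers. Under MinMaxScaler the four ordinary values are squeezed into roughly 0 to 0.03 and under StandardScaler into a span of about 0.08, while RobustScaler (median 3, interquartile range 2) keeps them at -1, -0.5, 0, 0.5 and pushes only the outlier far away.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler, MinMaxScaler, RobustScaler

# Hypothetical single feature: four ordinary values and one large outlier
x = np.array([[1.0], [2.0], [3.0], [4.0], [100.0]])

results = {}
for scaler in (StandardScaler(), MinMaxScaler(), RobustScaler()):
    # fit_transform learns the statistics and scales the column in one call
    results[type(scaler).__name__] = scaler.fit_transform(x).ravel()

for name, scaled in results.items():
    ordinary = scaled[:4]
    print(name, np.round(scaled, 3), "spread of ordinary values:",
          round(float(ordinary.max() - ordinary.min()), 3))
```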

     

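    validation_split: applied inside fit(). After the data has been split into train and test, validation_split=0.2 splits the train set again 8:2 so that the model is also validated during training. Keras takes the validation samples from the end of the training data, before shuffling. When validation_split is passed to fit(), val_loss and val_mse are added to history.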
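Keras documents that the validation data for validation_split is selected from the last samples of the x and y passed to fit(), before shuffling. The sketch below only mimics that slicing with NumPy so the 8:2 split is visible; `keras_style_validation_split` is a hypothetical helper, not a Keras function, and the 140-row training set assumes the usual 200-row Advertising data split 7:3.

```python
import numpy as np


def keras_style_validation_split(x, y, validation_split=0.2):
    """Hypothetical helper mimicking fit(validation_split=...):
    the LAST fraction of the samples is held out, with no shuffling."""
    split_at = int(len(x) * (1.0 - validation_split))
    return (x[:split_at], y[:split_at]), (x[split_at:], y[split_at:])


# Hypothetical training set: 140 rows (70% of 200), 3 features
x_train = np.arange(140 * 3, dtype=float).reshape(140, 3)
y_train = np.arange(140, dtype=float).reshape(140, 1)

(x_fit, y_fit), (x_val, y_val) = keras_style_validation_split(x_train, y_train, 0.2)
print(len(x_fit), len(x_val))  # 112 28
```

Because the tail is taken as is, it is worth shuffling before this step (train_test_split(shuffle=True) already does so in the code below) so the validation rows are not a biased slice.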

     

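    # Multiple linear regression model
    # scaling: when the features differ greatly in units, normalization/standardization is effective; applied to features only, not the label.
    # Standardization: (value - mean) / standard deviation
    # Normalization: (value - min) / (max - min)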
    
    import pandas as pd
    import numpy as np
    from keras.models import Sequential
    from keras.layers import Dense
    from sklearn.preprocessing import MinMaxScaler, minmax_scale, StandardScaler, RobustScaler
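    # StandardScaler: standardization; outliers distort the mean/std, so the result can be unbalanced
    # MinMaxScaler: normalization; sensitive to outliers
    # RobustScaler: uses the median and IQR, minimizing the influence of outliers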
    
    data = pd.read_csv('../testdata/Advertising.csv')
    del data['no'] # drop the 'no' column
    print(data.head(2))
    print(data.corr())
    
    fdata = data[['tv', 'radio', 'newspaper']]
    ldata = data[['sales']]
    print(fdata.head(2))
    print(ldata.head(2))
    
    # Scaling method 1 - normalization with the MinMaxScaler class
    # scaler = MinMaxScaler(feature_range=(0, 1))
    # fedata = scaler.fit_transform(fdata)
    # print(fedata)
    
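    # Scaling method 2 - the minmax_scale() function does the same in a single call.
    # Note: it is applied to all rows before train_test_split, so the test rows also affect the min/max.
    fedata = minmax_scale(fdata, feature_range=(0, 1), axis=0, copy=True)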
    # print(fedata)
    print(fdata.head(2))
    print(fedata[:2])
    
    # train / test
    from sklearn.model_selection import train_test_split
    x_train, x_test, y_train, y_test = train_test_split(fedata, ldata, shuffle=True,
                                                        test_size=0.3, random_state=123)
    
    model = Sequential()
    model.add(Dense(20, input_dim=3, activation='linear')) # hidden layers may also use activation='relu' (with only linear activations the stacked layers still form a linear model)
    model.add(Dense(10, activation='linear'))
    model.add(Dense(1, activation='linear'))
    
    model.compile(optimizer='adam', loss='mse', metrics=['mse'])
    print(model.summary()) # summary() prints the table itself and returns None, hence the 'None' in the output
    
    # Visualize the model structure only (plot_model needs pydot and graphviz installed)
    import tensorflow as tf
    # tf.keras.utils.plot_model(model, 'lin_model.png')
    
    history = model.fit(x_train, y_train, epochs=100,
                        batch_size=32, verbose=0,
                        validation_split=0.2) # split the train set (70% of the data) again 8:2; the last 20% is used for validation during training
    
    # Evaluate the model on the test set; evaluate() returns [loss, mse]
    loss = model.evaluate(x_test, y_test, batch_size=32, verbose=0)
    print('loss :', loss[0])
    
    # history values
    print(history.history) # when validation_split is passed to fit(), val_loss and val_mse are added to history
    print(history.history['loss'])
    print(history.history['val_loss'])
    print(history.history['mse'])
    print(history.history['val_mse'])
    
    import matplotlib.pyplot as plt
    plt.plot(history.history['loss'], label='loss')
    plt.plot(history.history['val_loss'], label='val_loss')
    plt.legend()
    plt.show()
    
    
    <console>
          tv  radio  newspaper  sales
    0  230.1   37.8       69.2   22.1
    1   44.5   39.3       45.1   10.4
                     tv     radio  newspaper     sales
    tv         1.000000  0.054809   0.056648  0.782224
    radio      0.054809  1.000000   0.354104  0.576223
    newspaper  0.056648  0.354104   1.000000  0.228299
    sales      0.782224  0.576223   0.228299  1.000000
          tv  radio  newspaper
    0  230.1   37.8       69.2
    1   44.5   39.3       45.1
       sales
    0   22.1
    1   10.4
          tv  radio  newspaper
    0  230.1   37.8       69.2
    1   44.5   39.3       45.1
    [[0.77578627 0.76209677 0.60598065]
     [0.1481231  0.79233871 0.39401935]]
    Model: "sequential"
    _________________________________________________________________
     Layer (type)                Output Shape              Param #   
    =================================================================
     dense (Dense)               (None, 20)                80        
                                                                     
     dense_1 (Dense)             (None, 10)                210       
                                                                     
     dense_2 (Dense)             (None, 1)                 11        
                                                                     
    =================================================================
    Total params: 301
    Trainable params: 301
    Non-trainable params: 0
    _________________________________________________________________
    None
    loss : 5.326688766479492
    {'loss': [220.86392211914062, 216.654541015625, 212.54075622558594, 208.59365844726562, 204.64442443847656, 200.69970703125, 196.8850555419922, 192.8991241455078, 188.89749145507812, 185.01803588867188, 180.98472595214844, 176.819091796875, 172.68994140625, 168.394775390625, 163.88644409179688, 159.5344696044922, 154.7161102294922, 149.84657287597656, 144.78062438964844, 139.7972869873047, 134.44863891601562, 129.08290100097656, 123.25224304199219, 117.62623596191406, 111.58130645751953, 105.47517395019531, 99.46146392822266, 93.2922134399414, 86.93135070800781, 80.5485610961914, 74.51576232910156, 68.19831848144531, 62.22127151489258, 56.33311080932617, 50.62906265258789, 45.09479522705078, 40.17760467529297, 35.40900802612305, 30.915903091430664, 27.265493392944336, 23.78342628479004, 20.69056510925293, 18.07109260559082, 15.832938194274902, 14.059842109680176, 12.700156211853027, 11.545656204223633, 10.673049926757812, 10.111808776855469, 9.743409156799316, 9.39438533782959, 9.16877555847168, 9.032331466674805, 8.899121284484863, 8.818928718566895, 8.738425254821777, 8.669133186340332, 8.584636688232422, 8.51762866973877, 8.443910598754883, 8.381913185119629, 8.30812931060791, 8.24895191192627, 8.173993110656738, 8.114842414855957, 8.034791946411133, 7.973105430603027, 7.912236213684082, 7.8424482345581055, 7.771606922149658, 7.70814847946167, 7.639877796173096, 7.587063789367676, 7.519417762756348, 7.458906650543213, 7.3990092277526855, 7.33789587020874, 7.2856245040893555, 7.226649761199951, 7.179852485656738, 7.109018802642822, 7.063086032867432, 7.003632545471191, 6.942445278167725, 6.892360687255859, 6.840705394744873, 6.7865705490112305, 6.7320556640625, 6.6846842765808105, 6.635317802429199, 6.582118034362793, 6.534093379974365, 6.485466957092285, 6.441600799560547, 6.393481254577637, 6.345686912536621, 6.297264575958252, 6.2458086013793945, 6.203938961029053, 6.157650947570801], 'mse': [220.86392211914062, 216.654541015625, 212.54075622558594, 
208.59365844726562, 204.64442443847656, 200.69970703125, 196.8850555419922, 192.8991241455078, 188.89749145507812, 185.01803588867188, 180.98472595214844, 176.819091796875, 172.68994140625, 168.394775390625, 163.88644409179688, 159.5344696044922, 154.7161102294922, 149.84657287597656, 144.78062438964844, 139.7972869873047, 134.44863891601562, 129.08290100097656, 123.25224304199219, 117.62623596191406, 111.58130645751953, 105.47517395019531, 99.46146392822266, 93.2922134399414, 86.93135070800781, 80.5485610961914, 74.51576232910156, 68.19831848144531, 62.22127151489258, 56.33311080932617, 50.62906265258789, 45.09479522705078, 40.17760467529297, 35.40900802612305, 30.915903091430664, 27.265493392944336, 23.78342628479004, 20.69056510925293, 18.07109260559082, 15.832938194274902, 14.059842109680176, 12.700156211853027, 11.545656204223633, 10.673049926757812, 10.111808776855469, 9.743409156799316, 9.39438533782959, 9.16877555847168, 9.032331466674805, 8.899121284484863, 8.818928718566895, 8.738425254821777, 8.669133186340332, 8.584636688232422, 8.51762866973877, 8.443910598754883, 8.381913185119629, 8.30812931060791, 8.24895191192627, 8.173993110656738, 8.114842414855957, 8.034791946411133, 7.973105430603027, 7.912236213684082, 7.8424482345581055, 7.771606922149658, 7.70814847946167, 7.639877796173096, 7.587063789367676, 7.519417762756348, 7.458906650543213, 7.3990092277526855, 7.33789587020874, 7.2856245040893555, 7.226649761199951, 7.179852485656738, 7.109018802642822, 7.063086032867432, 7.003632545471191, 6.942445278167725, 6.892360687255859, 6.840705394744873, 6.7865705490112305, 6.7320556640625, 6.6846842765808105, 6.635317802429199, 6.582118034362793, 6.534093379974365, 6.485466957092285, 6.441600799560547, 6.393481254577637, 6.345686912536621, 6.297264575958252, 6.2458086013793945, 6.203938961029053, 6.157650947570801], 'val_loss': [287.26190185546875, 282.3590393066406, 277.51446533203125, 272.6730041503906, 267.85101318359375, 263.0533752441406, 
258.25299072265625, 253.459716796875, 248.6145782470703, 243.67581176757812, 238.6717529296875, 233.59182739257812, 228.34561157226562, 222.9475860595703, 217.4134063720703, 211.6157684326172, 205.69094848632812, 199.54212951660156, 193.20950317382812, 186.6103515625, 179.8234405517578, 172.80113220214844, 165.67715454101562, 158.23936462402344, 150.66233825683594, 142.92640686035156, 135.0044708251953, 127.00032806396484, 119.00638580322266, 111.02684783935547, 102.97276306152344, 95.03031158447266, 87.1689682006836, 79.51530456542969, 72.15973663330078, 65.18345642089844, 58.47574234008789, 52.26217269897461, 46.557594299316406, 41.236873626708984, 36.52107620239258, 32.38870620727539, 28.768146514892578, 25.66649055480957, 23.017333984375, 20.7474365234375, 18.920696258544922, 17.460891723632812, 16.270702362060547, 15.317413330078125, 14.608662605285645, 14.084503173828125, 13.65961742401123, 13.325175285339355, 13.034199714660645, 12.817474365234375, 12.648688316345215, 12.539387702941895, 12.448770523071289, 12.355558395385742, 12.286001205444336, 12.185823440551758, 12.071789741516113, 12.017645835876465, 11.967878341674805, 11.886987686157227, 11.791418075561523, 11.73316764831543, 11.650167465209961, 11.550652503967285, 11.454337120056152, 11.374665260314941, 11.266190528869629, 11.186368942260742, 11.089245796203613, 10.996627807617188, 10.914887428283691, 10.833882331848145, 10.74878978729248, 10.723843574523926, 10.650252342224121, 10.597186088562012, 10.510519027709961, 10.414993286132812, 10.327677726745605, 10.23946475982666, 10.13911247253418, 10.076276779174805, 10.001184463500977, 9.906207084655762, 9.855804443359375, 9.789847373962402, 9.742588996887207, 9.645662307739258, 9.533889770507812, 9.435942649841309, 9.37006664276123, 9.330156326293945, 9.29210376739502, 9.246724128723145], 'val_mse': [287.26190185546875, 282.3590393066406, 277.51446533203125, 272.6730041503906, 267.85101318359375, 263.0533752441406, 258.25299072265625, 
253.459716796875, 248.6145782470703, 243.67581176757812, 238.6717529296875, 233.59182739257812, 228.34561157226562, 222.9475860595703, 217.4134063720703, 211.6157684326172, 205.69094848632812, 199.54212951660156, 193.20950317382812, 186.6103515625, 179.8234405517578, 172.80113220214844, 165.67715454101562, 158.23936462402344, 150.66233825683594, 142.92640686035156, 135.0044708251953, 127.00032806396484, 119.00638580322266, 111.02684783935547, 102.97276306152344, 95.03031158447266, 87.1689682006836, 79.51530456542969, 72.15973663330078, 65.18345642089844, 58.47574234008789, 52.26217269897461, 46.557594299316406, 41.236873626708984, 36.52107620239258, 32.38870620727539, 28.768146514892578, 25.66649055480957, 23.017333984375, 20.7474365234375, 18.920696258544922, 17.460891723632812, 16.270702362060547, 15.317413330078125, 14.608662605285645, 14.084503173828125, 13.65961742401123, 13.325175285339355, 13.034199714660645, 12.817474365234375, 12.648688316345215, 12.539387702941895, 12.448770523071289, 12.355558395385742, 12.286001205444336, 12.185823440551758, 12.071789741516113, 12.017645835876465, 11.967878341674805, 11.886987686157227, 11.791418075561523, 11.73316764831543, 11.650167465209961, 11.550652503967285, 11.454337120056152, 11.374665260314941, 11.266190528869629, 11.186368942260742, 11.089245796203613, 10.996627807617188, 10.914887428283691, 10.833882331848145, 10.74878978729248, 10.723843574523926, 10.650252342224121, 10.597186088562012, 10.510519027709961, 10.414993286132812, 10.327677726745605, 10.23946475982666, 10.13911247253418, 10.076276779174805, 10.001184463500977, 9.906207084655762, 9.855804443359375, 9.789847373962402, 9.742588996887207, 9.645662307739258, 9.533889770507812, 9.435942649841309, 9.37006664276123, 9.330156326293945, 9.29210376739502, 9.246724128723145]}
    [220.86392211914062, 216.654541015625, 212.54075622558594, 208.59365844726562, 204.64442443847656, 200.69970703125, 196.8850555419922, 192.8991241455078, 188.89749145507812, 185.01803588867188, 180.98472595214844, 176.819091796875, 172.68994140625, 168.394775390625, 163.88644409179688, 159.5344696044922, 154.7161102294922, 149.84657287597656, 144.78062438964844, 139.7972869873047, 134.44863891601562, 129.08290100097656, 123.25224304199219, 117.62623596191406, 111.58130645751953, 105.47517395019531, 99.46146392822266, 93.2922134399414, 86.93135070800781, 80.5485610961914, 74.51576232910156, 68.19831848144531, 62.22127151489258, 56.33311080932617, 50.62906265258789, 45.09479522705078, 40.17760467529297, 35.40900802612305, 30.915903091430664, 27.265493392944336, 23.78342628479004, 20.69056510925293, 18.07109260559082, 15.832938194274902, 14.059842109680176, 12.700156211853027, 11.545656204223633, 10.673049926757812, 10.111808776855469, 9.743409156799316, 9.39438533782959, 9.16877555847168, 9.032331466674805, 8.899121284484863, 8.818928718566895, 8.738425254821777, 8.669133186340332, 8.584636688232422, 8.51762866973877, 8.443910598754883, 8.381913185119629, 8.30812931060791, 8.24895191192627, 8.173993110656738, 8.114842414855957, 8.034791946411133, 7.973105430603027, 7.912236213684082, 7.8424482345581055, 7.771606922149658, 7.70814847946167, 7.639877796173096, 7.587063789367676, 7.519417762756348, 7.458906650543213, 7.3990092277526855, 7.33789587020874, 7.2856245040893555, 7.226649761199951, 7.179852485656738, 7.109018802642822, 7.063086032867432, 7.003632545471191, 6.942445278167725, 6.892360687255859, 6.840705394744873, 6.7865705490112305, 6.7320556640625, 6.6846842765808105, 6.635317802429199, 6.582118034362793, 6.534093379974365, 6.485466957092285, 6.441600799560547, 6.393481254577637, 6.345686912536621, 6.297264575958252, 6.2458086013793945, 6.203938961029053, 6.157650947570801]
    [287.26190185546875, 282.3590393066406, 277.51446533203125, 272.6730041503906, 267.85101318359375, 263.0533752441406, 258.25299072265625, 253.459716796875, 248.6145782470703, 243.67581176757812, 238.6717529296875, 233.59182739257812, 228.34561157226562, 222.9475860595703, 217.4134063720703, 211.6157684326172, 205.69094848632812, 199.54212951660156, 193.20950317382812, 186.6103515625, 179.8234405517578, 172.80113220214844, 165.67715454101562, 158.23936462402344, 150.66233825683594, 142.92640686035156, 135.0044708251953, 127.00032806396484, 119.00638580322266, 111.02684783935547, 102.97276306152344, 95.03031158447266, 87.1689682006836, 79.51530456542969, 72.15973663330078, 65.18345642089844, 58.47574234008789, 52.26217269897461, 46.557594299316406, 41.236873626708984, 36.52107620239258, 32.38870620727539, 28.768146514892578, 25.66649055480957, 23.017333984375, 20.7474365234375, 18.920696258544922, 17.460891723632812, 16.270702362060547, 15.317413330078125, 14.608662605285645, 14.084503173828125, 13.65961742401123, 13.325175285339355, 13.034199714660645, 12.817474365234375, 12.648688316345215, 12.539387702941895, 12.448770523071289, 12.355558395385742, 12.286001205444336, 12.185823440551758, 12.071789741516113, 12.017645835876465, 11.967878341674805, 11.886987686157227, 11.791418075561523, 11.73316764831543, 11.650167465209961, 11.550652503967285, 11.454337120056152, 11.374665260314941, 11.266190528869629, 11.186368942260742, 11.089245796203613, 10.996627807617188, 10.914887428283691, 10.833882331848145, 10.74878978729248, 10.723843574523926, 10.650252342224121, 10.597186088562012, 10.510519027709961, 10.414993286132812, 10.327677726745605, 10.23946475982666, 10.13911247253418, 10.076276779174805, 10.001184463500977, 9.906207084655762, 9.855804443359375, 9.789847373962402, 9.742588996887207, 9.645662307739258, 9.533889770507812, 9.435942649841309, 9.37006664276123, 9.330156326293945, 9.29210376739502, 9.246724128723145]
    [220.86392211914062, 216.654541015625, 212.54075622558594, 208.59365844726562, 204.64442443847656, 200.69970703125, 196.8850555419922, 192.8991241455078, 188.89749145507812, 185.01803588867188, 180.98472595214844, 176.819091796875, 172.68994140625, 168.394775390625, 163.88644409179688, 159.5344696044922, 154.7161102294922, 149.84657287597656, 144.78062438964844, 139.7972869873047, 134.44863891601562, 129.08290100097656, 123.25224304199219, 117.62623596191406, 111.58130645751953, 105.47517395019531, 99.46146392822266, 93.2922134399414, 86.93135070800781, 80.5485610961914, 74.51576232910156, 68.19831848144531, 62.22127151489258, 56.33311080932617, 50.62906265258789, 45.09479522705078, 40.17760467529297, 35.40900802612305, 30.915903091430664, 27.265493392944336, 23.78342628479004, 20.69056510925293, 18.07109260559082, 15.832938194274902, 14.059842109680176, 12.700156211853027, 11.545656204223633, 10.673049926757812, 10.111808776855469, 9.743409156799316, 9.39438533782959, 9.16877555847168, 9.032331466674805, 8.899121284484863, 8.818928718566895, 8.738425254821777, 8.669133186340332, 8.584636688232422, 8.51762866973877, 8.443910598754883, 8.381913185119629, 8.30812931060791, 8.24895191192627, 8.173993110656738, 8.114842414855957, 8.034791946411133, 7.973105430603027, 7.912236213684082, 7.8424482345581055, 7.771606922149658, 7.70814847946167, 7.639877796173096, 7.587063789367676, 7.519417762756348, 7.458906650543213, 7.3990092277526855, 7.33789587020874, 7.2856245040893555, 7.226649761199951, 7.179852485656738, 7.109018802642822, 7.063086032867432, 7.003632545471191, 6.942445278167725, 6.892360687255859, 6.840705394744873, 6.7865705490112305, 6.7320556640625, 6.6846842765808105, 6.635317802429199, 6.582118034362793, 6.534093379974365, 6.485466957092285, 6.441600799560547, 6.393481254577637, 6.345686912536621, 6.297264575958252, 6.2458086013793945, 6.203938961029053, 6.157650947570801]
    [287.26190185546875, 282.3590393066406, 277.51446533203125, 272.6730041503906, 267.85101318359375, 263.0533752441406, 258.25299072265625, 253.459716796875, 248.6145782470703, 243.67581176757812, 238.6717529296875, 233.59182739257812, 228.34561157226562, 222.9475860595703, 217.4134063720703, 211.6157684326172, 205.69094848632812, 199.54212951660156, 193.20950317382812, 186.6103515625, 179.8234405517578, 172.80113220214844, 165.67715454101562, 158.23936462402344, 150.66233825683594, 142.92640686035156, 135.0044708251953, 127.00032806396484, 119.00638580322266, 111.02684783935547, 102.97276306152344, 95.03031158447266, 87.1689682006836, 79.51530456542969, 72.15973663330078, 65.18345642089844, 58.47574234008789, 52.26217269897461, 46.557594299316406, 41.236873626708984, 36.52107620239258, 32.38870620727539, 28.768146514892578, 25.66649055480957, 23.017333984375, 20.7474365234375, 18.920696258544922, 17.460891723632812, 16.270702362060547, 15.317413330078125, 14.608662605285645, 14.084503173828125, 13.65961742401123, 13.325175285339355, 13.034199714660645, 12.817474365234375, 12.648688316345215, 12.539387702941895, 12.448770523071289, 12.355558395385742, 12.286001205444336, 12.185823440551758, 12.071789741516113, 12.017645835876465, 11.967878341674805, 11.886987686157227, 11.791418075561523, 11.73316764831543, 11.650167465209961, 11.550652503967285, 11.454337120056152, 11.374665260314941, 11.266190528869629, 11.186368942260742, 11.089245796203613, 10.996627807617188, 10.914887428283691, 10.833882331848145, 10.74878978729248, 10.723843574523926, 10.650252342224121, 10.597186088562012, 10.510519027709961, 10.414993286132812, 10.327677726745605, 10.23946475982666, 10.13911247253418, 10.076276779174805, 10.001184463500977, 9.906207084655762, 9.855804443359375, 9.789847373962402, 9.742588996887207, 9.645662307739258, 9.533889770507812, 9.435942649841309, 9.37006664276123, 9.330156326293945, 9.29210376739502, 9.246724128723145]
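The listing above calls minmax_scale() on the whole fdata before train_test_split, so the test rows help set the min and max. A common alternative is to split first, fit the scaler on the training rows only, and reuse it on the test rows (and on any new data at prediction time). The sketch below assumes random stand-in data instead of ../testdata/Advertising.csv; only the split and scaling pattern is the point.

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler

# Hypothetical stand-in for fdata (tv, radio, newspaper) and ldata (sales)
rng = np.random.default_rng(123)
fdata = rng.uniform(0, 300, size=(200, 3))
ldata = fdata @ np.array([[0.05], [0.2], [0.0]]) + rng.normal(0, 1, size=(200, 1))

# Split first, with the same settings as the listing
x_train, x_test, y_train, y_test = train_test_split(
    fdata, ldata, shuffle=True, test_size=0.3, random_state=123)

scaler = MinMaxScaler(feature_range=(0, 1))
x_train_s = scaler.fit_transform(x_train)  # learn min/max from training rows only
x_test_s = scaler.transform(x_test)        # reuse the training min/max on test rows

print(x_train_s.min(axis=0), x_train_s.max(axis=0))  # 0s and 1s
print(x_test_s.min(axis=0), x_test_s.max(axis=0))    # may fall slightly outside [0, 1]
```

The scaled arrays would then replace x_train and x_test in model.fit() and model.evaluate().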

     

     

    Visualizing loss and val_loss

    The two curves should follow similar paths. If val_loss starts rising while loss keeps falling, the model is overfitting.
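The overfitting rule of thumb can also be checked on the history lists rather than by eye. Keras has a built-in callback for this during training, tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=..., restore_best_weights=True); the sketch below is a hypothetical plain-Python helper that inspects the lists after training and reports the epoch where val_loss stopped improving while loss kept falling.

```python
def overfitting_start(loss, val_loss, patience=5):
    """Hypothetical helper: return the epoch index of the best val_loss if
    val_loss then fails to improve for `patience` epochs while the training
    loss keeps going down; return None if no such point is found."""
    best_epoch = 0
    for epoch in range(1, len(val_loss)):
        if val_loss[epoch] < val_loss[best_epoch]:
            best_epoch = epoch  # validation loss still improving
        elif epoch - best_epoch >= patience and loss[epoch] < loss[best_epoch]:
            return best_epoch   # train loss falls, val loss does not: overfitting
    return None


# Toy curves: val_loss bottoms out at epoch 3 and then rises
loss = [10, 8, 6, 5, 4, 3.5, 3, 2.5, 2, 1.5, 1]
val_loss = [11, 9, 7, 6.5, 6.8, 7, 7.5, 8, 8.5, 9, 9.5]
print(overfitting_start(loss, val_loss))  # 3
```

With the listing's history, the call would be overfitting_start(history.history['loss'], history.history['val_loss']); since val_loss kept decreasing over all 100 epochs there, it would return None.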

     

     

     

     

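    *** Checking the classical assumptions of linear regression the traditional way ***

    1) Normality: the residuals should follow a normal distribution.
    2) Independence: the residuals should be independent of one another (no autocorrelation).
    3) Linearity: as an independent variable (feature) changes, the dependent variable should change by a consistent amount.
    4) Homoscedasticity: the variance of the errors (residuals) should be constant across all values of the independent variables; the residuals should be spread evenly, without any particular pattern.
    5) Multicollinearity: in multiple regression, the independent variables should not be strongly correlated with one another.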
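Several of these assumptions have standard numeric checks. The sketch below computes them with NumPy and SciPy on hypothetical synthetic data in which x3 is almost a copy of x1: the Shapiro-Wilk test on the residuals (1), the Durbin-Watson statistic (2, near 2 means no autocorrelation), Koenker's studentized Breusch-Pagan test (4), and the variance inflation factor (5, values above about 10 are a common warning sign). Linearity (3) is usually judged from a residual-versus-fitted plot instead. The helper functions are written here for illustration; statsmodels offers ready-made versions (statsmodels.stats.stattools.durbin_watson, statsmodels.stats.diagnostic.het_breuschpagan, statsmodels.stats.outliers_influence.variance_inflation_factor).

```python
import numpy as np
from scipy import stats


def r_squared(target, predictors):
    """R^2 of an OLS fit of target on predictors (an intercept column is added)."""
    design = np.column_stack([np.ones(len(target)), predictors])
    coef, *_ = np.linalg.lstsq(design, target, rcond=None)
    resid = target - design @ coef
    return 1.0 - resid @ resid / np.sum((target - target.mean()) ** 2)


def durbin_watson(residuals):
    """Durbin-Watson statistic: values near 2 suggest no autocorrelation."""
    return np.sum(np.diff(residuals) ** 2) / np.sum(residuals ** 2)


def breusch_pagan(residuals, X):
    """Koenker's studentized Breusch-Pagan test: LM = n * R^2 of residuals^2 on X."""
    lm = len(residuals) * r_squared(residuals ** 2, X)
    return lm, stats.chi2.sf(lm, X.shape[1])


def vif(X):
    """Variance inflation factor per column: 1 / (1 - R^2 of that column on the rest)."""
    return np.array([1.0 / (1.0 - r_squared(X[:, j], np.delete(X, j, axis=1)))
                     for j in range(X.shape[1])])


# Hypothetical data: x3 is nearly a copy of x1, so x1 and x3 are collinear
rng = np.random.default_rng(0)
x1, x2 = rng.normal(size=200), rng.normal(size=200)
x3 = x1 + rng.normal(scale=0.05, size=200)
X = np.column_stack([x1, x2, x3])
y = 2 * x1 + 3 * x2 + rng.normal(size=200)

# Fit OLS and compute the residuals
design = np.column_stack([np.ones(len(y)), X])
beta, *_ = np.linalg.lstsq(design, y, rcond=None)
residuals = y - design @ beta

shapiro_stat, shapiro_p = stats.shapiro(residuals)  # 1) normality of residuals
dw = durbin_watson(residuals)                        # 2) independence of residuals
bp_lm, bp_p = breusch_pagan(residuals, X)            # 4) homoscedasticity
vifs = vif(X)                                        # 5) multicollinearity

print(f"Shapiro-Wilk p = {shapiro_p:.3f}, Durbin-Watson = {dw:.2f}, "
      f"Breusch-Pagan p = {bp_p:.3f}, VIF = {np.round(vifs, 1)}")
```

In this toy data the VIF of x1 and x3 comes out far above 10 while x2 stays near 1, flagging the collinear pair.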

     

     

     

    To practice in Colab, refer to the linked Daum Cafe post (cafe.daum.net).

     

     
