TensorFlow Basics 13 - Multiple linear regression model (scaling): normalization, standardization, validation_split
TensorFlow 2022. 11. 30. 17:41
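Multiple linear regression model
Scaling: when features are measured on very different scales, normalization or standardization is effective. It is applied to the features only, not to the label.
Standardization: (value - mean) / standard deviation
Normalization: (value - minimum) / (maximum - minimum)
StandardScaler: standardization. When outliers are present they distort the mean and standard deviation, so the scaled values can end up unbalanced.
MinMaxScaler: normalization. Sensitive to outliers, because the minimum and maximum set the range.
RobustScaler: minimizes the influence of outliers.
validation_split: applied while fit() is training. The data has already been divided into train and test; with validation_split=0.2 the train portion is split again 8:2 so that the model is also validated during training. When validation_split is passed to fit(), val_loss and val_mse appear in history.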
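To make the split sizes concrete, here is a minimal sketch that mimics what validation_split=0.2 does to the training arrays. It assumes a dataset of 200 rows (the row count of Advertising.csv is not shown in this post, so treat it as hypothetical), of which 140 remain after a 70/30 train/test split. It also relies on the documented Keras behaviour that the validation samples are taken from the end of the arrays passed to fit(), before any shuffling. The helper name split_like_keras is made up for this example.

```python
import math
import numpy as np


def split_like_keras(x, y, validation_split):
    """Hold out the trailing fraction of (x, y), the way fit(validation_split=...) does."""
    split_at = math.floor(len(x) * (1.0 - validation_split))
    return (x[:split_at], y[:split_at]), (x[split_at:], y[split_at:])


# Hypothetical stand-in for the 70% train portion (140 of an assumed 200 rows).
rng = np.random.default_rng(123)
x_train = rng.random((140, 3))
y_train = rng.random((140, 1))

(x_fit, y_fit), (x_val, y_val) = split_like_keras(x_train, y_train, 0.2)
print(len(x_fit), len(x_val))  # rows used for weight updates vs. rows used for val_loss
```

Because the validation rows are simply the tail of the array, it matters that train_test_split has already shuffled the data; with unshuffled, ordered data the validation set could be unrepresentative.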
```python
# Multiple linear regression model
# Scaling: when features are on very different scales, normalization/standardization
#          is effective. It is not applied to the label.
# Standardization: (value - mean) / standard deviation
# Normalization: (value - min) / (max - min)
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Dense
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler, minmax_scale, StandardScaler, RobustScaler
# StandardScaler: standardization; outliers can leave the scaled values unbalanced
# MinMaxScaler: normalization; sensitive to outliers
# RobustScaler: minimizes the influence of outliers

data = pd.read_csv('../testdata/Advertising.csv')
del data['no']  # drop the 'no' column
print(data.head(2))
print(data.corr())

fdata = data[['tv', 'radio', 'newspaper']]  # features
ldata = data[['sales']]                     # label
print(fdata.head(2))
print(ldata.head(2))

# Scaling, method 1 - normalization
# scaler = MinMaxScaler(feature_range=(0, 1))
# fedata = scaler.fit_transform(fdata)
# print(fedata)
fedata = minmax_scale(fdata, feature_range=(0, 1), axis=0, copy=True)
# print(fedata)
print(fdata.head(2))
print(fedata[:2])

# train / test
x_train, x_test, y_train, y_test = train_test_split(
    fedata, ldata, shuffle=True, test_size=0.3, random_state=123)

model = Sequential()
model.add(Dense(20, input_dim=3, activation='linear'))  # activation='relu' also works in hidden layers
model.add(Dense(10, activation='linear'))
model.add(Dense(1, activation='linear'))
model.compile(optimizer='adam', loss='mse', metrics=['mse'])
print(model.summary())  # summary() prints the table itself and returns None, hence the "None" in the output

# Visualize the model structure only
# tf.keras.utils.plot_model(model, 'lin_model.png')

history = model.fit(x_train, y_train, epochs=100, batch_size=32, verbose=0,
                    validation_split=0.2)
# The 0.7 train portion is split again 8:2 so the model is also validated during training.

# Evaluate the model and check the score
loss = model.evaluate(x_test, y_test, batch_size=32, verbose=0)
print('loss :', loss[0])

# history values
print(history.history)  # with validation_split, val_loss and val_mse also appear in history
print(history.history['loss'])
print(history.history['val_loss'])
print(history.history['mse'])
print(history.history['val_mse'])

plt.plot(history.history['loss'], label='loss')
plt.plot(history.history['val_loss'], label='val_loss')
plt.legend()
plt.show()
```

In the output below, each list holds 100 values, one per epoch. The 'mse' and 'val_mse' lists are identical to 'loss' and 'val_loss' because the loss and the metric are both MSE, and the four print calls after the dictionary repeat the same lists, so each list is shown only once.

<console>
```
      tv  radio  newspaper  sales
0  230.1   37.8       69.2   22.1
1   44.5   39.3       45.1   10.4
                 tv     radio  newspaper     sales
tv         1.000000  0.054809   0.056648  0.782224
radio      0.054809  1.000000   0.354104  0.576223
newspaper  0.056648  0.354104   1.000000  0.228299
sales      0.782224  0.576223   0.228299  1.000000
      tv  radio  newspaper
0  230.1   37.8       69.2
1   44.5   39.3       45.1
   sales
0   22.1
1   10.4
      tv  radio  newspaper
0  230.1   37.8       69.2
1   44.5   39.3       45.1
[[0.77578627 0.76209677 0.60598065]
 [0.1481231  0.79233871 0.39401935]]
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 dense (Dense)               (None, 20)                80
 dense_1 (Dense)             (None, 10)                210
 dense_2 (Dense)             (None, 1)                 11
=================================================================
Total params: 301
Trainable params: 301
Non-trainable params: 0
_________________________________________________________________
None
loss : 5.326688766479492
{'loss': [220.86392211914062, 216.654541015625, 212.54075622558594, 208.59365844726562, 204.64442443847656,
  200.69970703125, 196.8850555419922, 192.8991241455078, 188.89749145507812, 185.01803588867188,
  180.98472595214844, 176.819091796875, 172.68994140625, 168.394775390625, 163.88644409179688,
  159.5344696044922, 154.7161102294922, 149.84657287597656, 144.78062438964844, 139.7972869873047,
  134.44863891601562, 129.08290100097656, 123.25224304199219, 117.62623596191406, 111.58130645751953,
  105.47517395019531, 99.46146392822266, 93.2922134399414, 86.93135070800781, 80.5485610961914,
  74.51576232910156, 68.19831848144531, 62.22127151489258, 56.33311080932617, 50.62906265258789,
  45.09479522705078, 40.17760467529297, 35.40900802612305, 30.915903091430664, 27.265493392944336,
  23.78342628479004, 20.69056510925293, 18.07109260559082, 15.832938194274902, 14.059842109680176,
  12.700156211853027, 11.545656204223633, 10.673049926757812, 10.111808776855469, 9.743409156799316,
  9.39438533782959, 9.16877555847168, 9.032331466674805, 8.899121284484863, 8.818928718566895,
  8.738425254821777, 8.669133186340332, 8.584636688232422, 8.51762866973877, 8.443910598754883,
  8.381913185119629, 8.30812931060791, 8.24895191192627, 8.173993110656738, 8.114842414855957,
  8.034791946411133, 7.973105430603027, 7.912236213684082, 7.8424482345581055, 7.771606922149658,
  7.70814847946167, 7.639877796173096, 7.587063789367676, 7.519417762756348, 7.458906650543213,
  7.3990092277526855, 7.33789587020874, 7.2856245040893555, 7.226649761199951, 7.179852485656738,
  7.109018802642822, 7.063086032867432, 7.003632545471191, 6.942445278167725, 6.892360687255859,
  6.840705394744873, 6.7865705490112305, 6.7320556640625, 6.6846842765808105, 6.635317802429199,
  6.582118034362793, 6.534093379974365, 6.485466957092285, 6.441600799560547, 6.393481254577637,
  6.345686912536621, 6.297264575958252, 6.2458086013793945, 6.203938961029053, 6.157650947570801],
 'mse': [same 100 values as 'loss'],
 'val_loss': [287.26190185546875, 282.3590393066406, 277.51446533203125, 272.6730041503906, 267.85101318359375,
  263.0533752441406, 258.25299072265625, 253.459716796875, 248.6145782470703, 243.67581176757812,
  238.6717529296875, 233.59182739257812, 228.34561157226562, 222.9475860595703, 217.4134063720703,
  211.6157684326172, 205.69094848632812, 199.54212951660156, 193.20950317382812, 186.6103515625,
  179.8234405517578, 172.80113220214844, 165.67715454101562, 158.23936462402344, 150.66233825683594,
  142.92640686035156, 135.0044708251953, 127.00032806396484, 119.00638580322266, 111.02684783935547,
  102.97276306152344, 95.03031158447266, 87.1689682006836, 79.51530456542969, 72.15973663330078,
  65.18345642089844, 58.47574234008789, 52.26217269897461, 46.557594299316406, 41.236873626708984,
  36.52107620239258, 32.38870620727539, 28.768146514892578, 25.66649055480957, 23.017333984375,
  20.7474365234375, 18.920696258544922, 17.460891723632812, 16.270702362060547, 15.317413330078125,
  14.608662605285645, 14.084503173828125, 13.65961742401123, 13.325175285339355, 13.034199714660645,
  12.817474365234375, 12.648688316345215, 12.539387702941895, 12.448770523071289, 12.355558395385742,
  12.286001205444336, 12.185823440551758, 12.071789741516113, 12.017645835876465, 11.967878341674805,
  11.886987686157227, 11.791418075561523, 11.73316764831543, 11.650167465209961, 11.550652503967285,
  11.454337120056152, 11.374665260314941, 11.266190528869629, 11.186368942260742, 11.089245796203613,
  10.996627807617188, 10.914887428283691, 10.833882331848145, 10.74878978729248, 10.723843574523926,
  10.650252342224121, 10.597186088562012, 10.510519027709961, 10.414993286132812, 10.327677726745605,
  10.23946475982666, 10.13911247253418, 10.076276779174805, 10.001184463500977, 9.906207084655762,
  9.855804443359375, 9.789847373962402, 9.742588996887207, 9.645662307739258, 9.533889770507812,
  9.435942649841309, 9.37006664276123, 9.330156326293945, 9.29210376739502, 9.246724128723145],
 'val_mse': [same 100 values as 'val_loss']}
```
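The script above only uses min-max scaling. To see how the three scaling approaches listed at the top of this post react to an outlier, here is a small NumPy sketch. The function names (standardize, min_max, robust) and the sample column are made up for illustration. The formulas are meant to mirror the default behaviour of StandardScaler (population standard deviation), MinMaxScaler, and RobustScaler (median and the 25th to 75th percentile range), but this is a hand-written approximation, not the scikit-learn implementation.

```python
import numpy as np


def standardize(col):
    # (value - mean) / standard deviation, using the population std (ddof=0)
    return (col - col.mean()) / col.std()


def min_max(col):
    # (value - min) / (max - min), which maps the column onto [0, 1]
    return (col - col.min()) / (col.max() - col.min())


def robust(col):
    # (value - median) / interquartile range, so extreme values do not set the scale
    q1, median, q3 = np.percentile(col, [25, 50, 75])
    return (col - median) / (q3 - q1)


# Made-up feature column with a single outlier (100.0).
col = np.array([1.0, 2.0, 3.0, 4.0, 100.0])

print("standardized:", standardize(col))
print("min-max     :", min_max(col))
print("robust      :", robust(col))
```

With min-max scaling the four ordinary values are squeezed into a narrow band near 0 because the outlier defines the maximum. With the robust version the ordinary values keep a spread of about one unit and only the outlier lands far away.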
Visualization of loss and val_loss: the two curves should follow similar paths. If val_loss alone starts rising while loss keeps falling, the model is overfitting.
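The same check can be done numerically instead of by eye. The sketch below is one simple heuristic, not a standard API: best_epoch and looks_overfit are hypothetical helper names, patience=5 is an arbitrary choice, and the two curves are made-up numbers rather than values from this post.

```python
def best_epoch(val_loss):
    """Return the 0-based index of the epoch with the lowest validation loss."""
    return min(range(len(val_loss)), key=val_loss.__getitem__)


def looks_overfit(loss, val_loss, patience=5):
    """True if val_loss has not improved for `patience` epochs while loss kept falling."""
    best = best_epoch(val_loss)
    stalled = (len(val_loss) - 1 - best) >= patience
    still_fitting = loss[-1] < loss[best]
    return stalled and still_fitting


# Made-up curves: training loss keeps falling, validation loss turns upward.
loss = [9.0, 7.0, 5.0, 4.0, 3.0, 2.5, 2.0, 1.6, 1.3, 1.0]
val_loss = [9.5, 7.6, 6.0, 5.5, 5.7, 6.0, 6.4, 6.9, 7.5, 8.2]

print(best_epoch(val_loss))
print(looks_overfit(loss, val_loss))
```

In the run shown in this post, val_loss is still decreasing at the final epoch, so a check like this would not flag overfitting. Keras also ships an EarlyStopping callback that stops training based on a monitored value and a patience setting, which is usually the more practical route.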
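*** Checking the classical assumptions of linear regression analysis (the traditional approach) ***
1) Normality: the residuals should follow a normal distribution.
2) Independence: the residuals should be independent of one another (no autocorrelation).
3) Linearity: the dependent variable should change by a constant amount as an independent variable (feature) changes.
4) Homoscedasticity: the variance should be similar across groups. The variance of the errors (residuals) should be constant for all values of the independent variables, and the residuals should be spread evenly with no particular pattern.
5) Multicollinearity: in multiple regression, there should be no strong correlation among the independent variables.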
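Two of these assumptions are easy to check with a few lines of NumPy. The sketch below computes the variance inflation factor (VIF) for multicollinearity and the Durbin-Watson statistic for independence of residuals. The function names and the random data are made up for this example, and the usual rules of thumb (a VIF above roughly 10 suggests multicollinearity, a Durbin-Watson value near 2 suggests no first-order autocorrelation) are conventions rather than strict thresholds.

```python
import numpy as np


def vif(X):
    """Variance inflation factor for each column of X, shape (n_samples, n_features)."""
    X = np.asarray(X, dtype=float)
    factors = []
    for j in range(X.shape[1]):
        target = X[:, j]
        others = np.delete(X, j, axis=1)
        # Regress column j on the remaining columns (with an intercept).
        A = np.column_stack([np.ones(len(X)), others])
        coef, *_ = np.linalg.lstsq(A, target, rcond=None)
        resid = target - A @ coef
        r2 = 1.0 - resid.var() / target.var()
        factors.append(1.0 / (1.0 - r2))
    return np.array(factors)


def durbin_watson(resid):
    """Sum of squared successive differences divided by the sum of squared residuals."""
    resid = np.asarray(resid, dtype=float)
    return np.sum(np.diff(resid) ** 2) / np.sum(resid ** 2)


# Made-up data: a and b are unrelated, c is almost a copy of a.
rng = np.random.default_rng(0)
a = rng.normal(size=200)
b = rng.normal(size=200)
c = a + 0.05 * rng.normal(size=200)

vif_independent = vif(np.column_stack([a, b]))
vif_collinear = vif(np.column_stack([a, b, c]))
dw_noise = durbin_watson(rng.normal(size=200))

print("VIF, unrelated features :", vif_independent)
print("VIF, with near-duplicate:", vif_collinear)
print("Durbin-Watson on noise  :", dw_noise)
```

For the model in this post, the residuals would be y_test minus model.predict(x_test), and the feature matrix would be the scaled tv, radio, and newspaper columns.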
To practice in Colab, see the linked Daum Cafe page (cafe.daum.net).