Python 데이터 분석

Python 데이터분석 기초 68 - SVM으로 이미지 분류

코딩탕탕 2022. 11. 24. 13:16

 

SVM으로 이미지를 분석하였다. 세계 정치인 중 일부 사진을 사용

주성분 분석으로 이미지 차원 축소, train/test 분할, 시각화를 실시

 

# Image classification with an SVM
# Uses photos of a subset of world politicians (the LFW dataset)

from sklearn.svm import SVC
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_lfw_people
from sklearn.pipeline import make_pipeline

# Load the Labeled Faces in the Wild dataset, keeping only people with
# at least 60 images.
faces = fetch_lfw_people(min_faces_per_person = 60, color = False) # color: True = RGB, False = grayscale
# print(faces)
# print(faces.DESCR)

# Exploratory look at the data: flattened pixels, integer labels, names.
print(faces.data[:3], ' ', faces.data.shape) # (1277, 2914) - each row is a flattened 62x47 image
print(faces.target, set(faces.target))       # labels 0..6 -> seven people
print(faces.target_names)
print(faces.images.shape) # (1277, 62, 47)

# print(faces.images[0])
# print(faces.target_names[faces.target[0]])
# plt.imshow(faces.images[0], cmap='bone')
# plt.show()

# Disabled preview: show the first 15 faces labeled with their names.
"""
fig, ax = plt.subplots(3, 5)
# print(fig)
# print(ax.flat)
for i, axi in enumerate(ax.flat):
    axi.imshow(faces.images[i], cmap='bone')
    axi.set(xticks=[], yticks=[], xlabel=faces.target_names[faces.target[i]])
plt.show()
"""

# --- Dimensionality reduction -----------------------------------------
# Project each 2914-pixel face onto its first 150 principal components.
# whiten=True rescales the projected components to unit variance.
pca = PCA(n_components = 150, whiten=True, random_state=0)
reduced = pca.fit_transform(faces.data)
print('x_low :', reduced[:1], reduced.shape) # (1277, 150)
print(pca.explained_variance_ratio_)

# --- Model -------------------------------------------------------------
# Bundle the preprocessor (PCA) and the classifier (SVC) into a single
# pipeline so fit/predict run the two steps in sequence.
svc = SVC(C=1)
model = make_pipeline(pca, svc)
print(model)
# Pipeline(steps=[('pca', PCA(n_components=150, random_state=0, whiten=True)), ('svc', SVC(C=1))])

# --- Train / test split (default 75/25) --------------------------------
from sklearn.model_selection import train_test_split

x_train, x_test, y_train, y_test = train_test_split(faces.data, faces.target, random_state=1)
print(x_train.shape, x_test.shape, y_train.shape, y_test.shape) # (957, 2914) (320, 2914) (957,) (320,)

# Fit the PCA+SVC pipeline on the raw training pixels; the pipeline's PCA
# step is re-fitted on the training fold only.
model.fit(x_train, y_train)
pred = model.predict(x_test)
print('예측값 :', pred[:10])
print('실제값 :', y_test[:10])

# --- Evaluation --------------------------------------------------------
from sklearn.metrics import classification_report, accuracy_score, confusion_matrix

report = classification_report(y_test, pred, target_names=faces.target_names)
print('classification_report : \n', report)
cm = confusion_matrix(y_test, pred)
print('con_mat :\n', cm)
print('acc :', accuracy_score(y_test, pred)) # 0.771875

# Visualize prediction results on the test set: each subplot shows a test
# face labeled with the predicted surname — black when the prediction is
# correct, red when it is wrong.
fig, ax = plt.subplots(4, 6)
fig.suptitle('pred result', size = 14)  # set once; was redundantly re-set inside the loop
for i, axi in enumerate(ax.flat):
    # Rows of x_test are flattened pixels; restore the 62x47 image shape.
    axi.imshow(x_test[i].reshape(62, 47), cmap='bone')
    axi.set(xticks=[], yticks=[])
    axi.set_ylabel(faces.target_names[pred[i]].split()[-1], color = 'black' if pred[i] == y_test[i] else 'red')

plt.show()



<console>
[[112.       134.       148.33333  ...  56.666668  59.        57.666668]
 [ 90.333336  96.       103.333336 ...  98.333336 101.666664 104.666664]
 [ 42.666668  39.666668  55.666668 ... 103.       141.33333  168.      ]]   (1277, 2914)
[2 3 3 ... 5 3 5] {0, 1, 2, 3, 4, 5, 6}
['Ariel Sharon' 'Colin Powell' 'Donald Rumsfeld' 'George W Bush'
 'Gerhard Schroeder' 'Junichiro Koizumi' 'Tony Blair']
(1277, 62, 47)
x_low : [[ 1.3677825   1.2960404  -2.1041217  -2.1586986  -0.435201   -0.4743769
   0.9264316   1.3343427   0.43732992 -0.9135895   2.2486153  -0.7178624
  -0.4068734  -1.3802292  -0.29159927 -0.05936835  0.73455346  2.8088133
  -1.568095   -1.8362647  -0.84274054  0.770162   -1.5726647   0.32385513
   1.4969673  -0.42508847 -1.0940268  -1.0819191   0.97768533 -0.68225
   0.99427944 -0.9928934  -1.200369    3.6187801  -2.5437155  -0.33215645
  -0.322823   -0.802046    2.6174054  -0.2712837   1.3795964  -1.291398
  -1.331111   -0.9579771  -0.890986    1.3076365   2.6342285   2.270725
  -1.8776215   0.16426027  0.31948033  1.0879174   0.9021022  -0.37024105
   1.5219023  -0.11264389  1.4595239  -3.2084208   0.698089    1.4241335
   1.9436735  -1.1137776  -2.1509817   0.26504955  0.71430373 -2.027011
  -1.0641276   2.2356846   1.007661   -0.04664142 -0.29766974  0.45365828
   1.9266397   0.6258489  -1.0495006  -0.3399345   2.4036052  -1.8655828
   1.5655178  -0.56435126 -1.2220348   1.7108462  -0.9093841   0.6768287
   0.9239087  -1.8770708   0.75470287 -1.1503301  -0.8594827   0.54812485
   0.6937472   0.16379279  0.46024537 -1.4094268   1.4134499   0.21572134
  -0.40187916  1.4675306   0.28865305  0.876322   -0.67264843 -0.14912999
   0.11965517 -1.4723855  -1.3445941   0.15150931 -0.3397091  -1.1231271
  -1.5955024   0.752204   -0.22630595  0.44366688  0.5599929  -0.8407435
   0.46005732 -1.164607    2.4685104  -1.5612679   1.2394214  -0.40660897
  -2.200386    0.9096232   1.6126409   1.5690994   1.746112    2.0644932
   3.3641155  -0.04552075 -1.9216464  -1.6330186   0.18395203  2.6069992
   0.82801706 -1.3873518   1.7612125  -0.1259181  -1.6687584  -0.9706199
   0.5388545  -2.8701713   1.7484329   0.21548973  1.2080355  -0.58761966
   2.8680964  -1.4097466  -0.13887545  0.18880442  0.5082952  -0.9576872 ]] (1277, 150)
[0.18526202 0.14934969 0.07120554 0.05920597 0.051394   0.02949642
 0.02496833 0.02074919 0.01956663 0.01904831 0.01543423 0.01405212
 0.01220285 0.01094312 0.01035904 0.00964713 0.00914998 0.0087581
 0.00819499 0.0070755  0.00694857 0.00650467 0.00608582 0.0058524
 0.0054521  0.0051052  0.00488408 0.00477356 0.00445718 0.00432005
 0.00402443 0.00379395 0.00364652 0.00353308 0.00347345 0.00328357
 0.00317543 0.00313066 0.00305568 0.00291224 0.00284832 0.00276779
 0.00271521 0.0026042  0.002454   0.0023958  0.00233838 0.00233531
 0.00230086 0.00213147 0.00208963 0.00205583 0.00202007 0.00195988
 0.00194176 0.00190267 0.0018513  0.00184272 0.00178145 0.00172218
 0.00171553 0.00166959 0.00161362 0.00159158 0.00154077 0.00151684
 0.00149804 0.00146335 0.00143073 0.00141307 0.00137757 0.0013638
 0.00134382 0.00131184 0.00128164 0.00126804 0.00123963 0.00122841
 0.00120151 0.00119128 0.00117175 0.00115664 0.0011434  0.00108655
 0.00107621 0.00107286 0.00103867 0.00101718 0.00100961 0.00099422
 0.00097838 0.00095651 0.00095328 0.00094091 0.0009277  0.00091337
 0.00088801 0.00086844 0.00086317 0.00084927 0.00083797 0.00082896
 0.00080492 0.00079063 0.00077659 0.00075959 0.00075335 0.00075086
 0.00072738 0.0007242  0.00071693 0.00069781 0.00069663 0.00068499
 0.00068163 0.00066933 0.00066214 0.00065029 0.00064562 0.00063024
 0.00061722 0.00060555 0.0005976  0.00058633 0.00058511 0.00058192
 0.000573   0.00056909 0.00056206 0.00055223 0.0005379  0.00053007
 0.00052423 0.00051866 0.00051385 0.00050846 0.00050295 0.0004923
 0.0004828  0.00048074 0.00047113 0.00046349 0.00045476 0.00044763
 0.00044282 0.00043672 0.00043104 0.00042354 0.00042013 0.00041644]
Pipeline(steps=[('pca', PCA(n_components=150, random_state=0, whiten=True)),
                ('svc', SVC(C=1))])
(957, 2914) (320, 2914) (957,) (320,)
예측값 : [1 3 3 3 3 3 3 3 3 5]
실제값 : [1 2 0 5 3 4 3 6 3 5]
classification_report : 
                    precision    recall  f1-score   support

     Ariel Sharon       1.00      0.32      0.48        19
     Colin Powell       0.83      0.89      0.86        54
  Donald Rumsfeld       1.00      0.38      0.55        34
    George W Bush       0.69      0.99      0.81       138
Gerhard Schroeder       1.00      0.50      0.67        28
Junichiro Koizumi       1.00      0.67      0.80        18
       Tony Blair       0.94      0.59      0.72        29

         accuracy                           0.77       320
        macro avg       0.92      0.62      0.70       320
     weighted avg       0.83      0.77      0.75       320

con_mat :
 [[  6   2   0  11   0   0   0]
 [  0  48   0   6   0   0   0]
 [  0   6  13  15   0   0   0]
 [  0   1   0 137   0   0   0]
 [  0   1   0  12  14   0   1]
 [  0   0   0   6   0  12   0]
 [  0   0   0  12   0   0  17]]
acc : 0.771875

 

 

예측 결과 시각화 black = True, red = False