확률적 식별 모델
- 다중클래스 로지스틱 회귀
- 우도함수 소프트맥스, 시그모이드의 함수들 사이의 관계를 잘 기억할것
체인룰을 쓰기 위해 함수관계를 잘 이해해야 한다.
- Gradient Descent(batch)
- 사이킷런의 make_classification을 통해 쉽게 데이터를 만들어낼 수 있다.
- 로지스틱 리그레션 모델에서 반드시 써야하는 함수가 시그모이드 함수이다. 지수함수에 1을더해서 나눈것을 시그모이드 함수로 정의한다.
- Cost를 계산하는 함수 : 현재의 파라미터값 w가 주어져 있을 때 입력 X와 목표값 t에 대해 cost를 계산한다.
- 전체 데이터를 한꺼번에 넣어서 그래디언트를 업데이트 하는데, 이를 배치 업데이트라고한다. 배치란 데이터 전체를 한꺼번에 원샷으로 다 쓴다는 의미이다. 딥러닝에서 말하는 배치사이즈는 보통 미니배치를 말하는 것이다. 여기서 배치라는 것은 전체 데이터를 다 쓴다는 것이다.
- 1이중분류(binary classification)
이중분류 문제의 경우 클래스가 비슷한 분포를 가지고 있지 않을 때 불균형의 문제가 있어 정확도만 구하는 것은 좋은 지표가 될 수 없다.
다음 것들을 고려해야 한다.
- 오차행렬(Confusion matrix)
- 오차행렬을 구하기 위해서 사이킷런에 포함된 오차행렬을 사용한다.
- TP(True Positive) : 모델이 Positive인데 실제로도 positive라고 예측한 경우
- FP(False Positive) : 모델이 negative인데 positive라고 예측한 경우
- TN(True Negative) : 모델이 negative인데 negative라고 예측한 경우
- FN(False Negative) : 모델이 Positive인데 negative라고 예측한 경우
- 정밀도 : 모델이 positive라고 했을 때 그중 몇개가 정말로 Positive인지의 비율
- 재현율 : 데이터에서 positive인 것들의 개수 중에서 모델이 얼마나 잘 positive를 찾아냈는지에 대한 비율 어떤 모델은 precision이 높고 recall이 낮으며, 어떤 모델은 그 반대일 수 있다. 어떤 모델이 좋은 모델일까? 이것은 경우에 따라서 다를 수 있다.
의료진단의 경우 precision보다 recall이 중요하다. FP인 경우 병이 있다고 판단했는데 실제로는 병이 없는 경우이다. 번거롭지만 리스크가 큰 것은 아니다.
하지만 FN인 경우 실제로 병이 있는데 병이 없다고 판단하게 되면 리스크가 크다. 이런 경우 recall이 높아야 한다.
스팸을 분류하는 경우를 생각해보자. Precision을 높이기 위해 recall이 낮아지면, 스팸이 아닌데 스팸이라고 판단되었는데 중요한 이메일이면 리스크가 크다. 따라서 스팸을 분류하는 경우도 recall이 더 중요하다.
동영상을 분류하는 경우(어린이에게 적절한 동영상인지) precision이 더 중요하다. 분류기가 잘못해서 좋은 동영상이 아닌데 보여주는 것은, 그 하나 때문에 큰 리스크가 발생할 수 있다.
하나의 모델 안에서 precision과 recall을 튜닝할 수 있다.
위의 사진을 보면 오른쪽으로 갈수록 이 이미지들이 높은 스코어를 가진다.
주어진 이미지에 대해 positive/negative를 판별하기 위해 스코어를 보는데, 스코어 값이 특정 기준보다 크냐/작냐에 따라 판별한다.
그 기준을 threshold라고 한다.
Threshold을 낮게 잡으면 모든 경우에 positive라고 말한다. Recall이 1에 가깝게 된다. 존재하는 모든 positive를 positive라고 예측하기 때문이다. 반면 precision은 낮을 수밖에 없다.
반면 Threshold를 높게 작으면 recall은 낮아지고 precision은 높아진다.
그래서 중간정도의 threshold를 많이 사용한다. 하지만 앞의 암이나 스팸 메일등 precision과 recall의 중요도에 따라 달라지기도 한다.
어떤지점이 가장 좋은 지점일까? 급격한 변화가 일어나기 전의 지점을 threshold로 지정하면 좋은 trade-off라고 할 수 있다.
하나의 모델 안에서 precision, recall을 조정하기 위해서는 모델의 예측값이 참이냐 거짓이냐 보다는 스코어를 알아야 한다.
다중 분류(multiclass classification)
이진문제가 아닌 다중분류의 경우에는 클래스가 비슷한 분포를 가지고 있기 때문에 그냥 정확도만 구해도 불균형의 문제가 별로 없어서 정밀도나 재현율을 구하지 않고 정확도를 구하는 것도 괜찮은 지표가 된다.
- Data Augmentation
- 학습한 모델의 성능을 향상시키기 위해 data augmentation이라는 방법을 사용
- 이미지에 변형을 가해서 추가적인 데이터를 만들어서 학습데이터에 포함해 모델을 새로 학습하면 좀더 안정적인 모델을 만들 수 있다.
실습 - MNIST
# Python ≥3.5 is required
import sys
assert sys.version_info >= (3, 5)
# Scikit-Learn ≥0.20 is required
import sklearn
assert sklearn.__version__ >= "0.20"
# Common imports
import numpy as np
import os
# to make this notebook's output stable across runs
np.random.seed(42)
# To plot pretty figures
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.rc('axes', labelsize=14)
mpl.rc('xtick', labelsize=12)
mpl.rc('ytick', labelsize=12)
# Where to save the figures
PROJECT_ROOT_DIR = "."
CHAPTER_ID = "classification"
IMAGES_PATH = os.path.join(PROJECT_ROOT_DIR, "images", CHAPTER_ID)
os.makedirs(IMAGES_PATH, exist_ok=True)
def save_fig(fig_id, tight_layout=True, fig_extension="png", resolution=300):
path = os.path.join(IMAGES_PATH, fig_id + "." + fig_extension)
print("Saving figure", fig_id)
if tight_layout:
plt.tight_layout()
plt.savefig(path, format=fig_extension, dpi=resolution)
MNIST 데이터¶
from sklearn.datasets import fetch_openml
mnist = fetch_openml('mnist_784', version=1, cache=True)
mnist.keys()
dict_keys(['data', 'target', 'frame', 'categories', 'feature_names', 'target_names', 'DESCR', 'details', 'url'])
X, y = mnist["data"], mnist["target"]
X.shape
(70000, 784)
X = np.array(X)
y
0 5
1 0
2 4
3 1
4 9
..
69995 2
69996 3
69997 4
69998 5
69999 6
Name: class, Length: 70000, dtype: category
Categories (10, object): ['0', '1', '2', '3', ..., '6', '7', '8', '9']
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
some_digit = X[2]
some_digit_image = some_digit.reshape(28, 28)
plt.imshow(some_digit_image, cmap=mpl.cm.binary)
plt.axis("off")
save_fig("some_digit_plot")
plt.show()
Saving figure some_digit_plot
some_digit
array([ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 67., 232., 39., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 62., 81., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 120., 180., 39., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 126., 163., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 2., 153., 210., 40., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 220., 163., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 27., 254., 162., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 222., 163., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 183., 254., 125., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 46., 245., 163.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 198., 254., 56., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 120., 254., 163., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 23., 231., 254., 29.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 159., 254.,
120., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 163., 254., 216., 16., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 159., 254., 67., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 14., 86., 178., 248., 254., 91.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 159.,
254., 85., 0., 0., 0., 47., 49., 116., 144., 150., 241.,
243., 234., 179., 241., 252., 40., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 150., 253., 237., 207., 207., 207.,
253., 254., 250., 240., 198., 143., 91., 28., 5., 233., 250.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 119., 177., 177., 177., 177., 177., 98., 56., 0., 0.,
0., 0., 0., 102., 254., 220., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 169., 254.,
137., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 169., 254., 57., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 169.,
254., 57., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 169., 255., 94., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
169., 254., 96., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 169., 254., 153., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 169., 255., 153., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 96., 254., 153., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0.])
#배열이 아닌 정수로 저장
y = y.astype(np.uint8)
y
0 5
1 0
2 4
3 1
4 9
..
69995 2
69996 3
69997 4
69998 5
69999 6
Name: class, Length: 70000, dtype: uint8
def plot_digit(data):
image = data.reshape(28, 28)
plt.imshow(image, cmap = mpl.cm.binary,
interpolation="nearest")
plt.axis("off")
def plot_digits(instances, images_per_row=10, **options):
size = 28
images_per_row = min(len(instances), images_per_row)
images = [instance.reshape(size,size) for instance in instances]
n_rows = (len(instances) - 1) // images_per_row + 1
row_images = []
n_empty = n_rows * images_per_row - len(instances)
images.append(np.zeros((size, size * n_empty)))
for row in range(n_rows):
rimages = images[row * images_per_row : (row + 1) * images_per_row]
row_images.append(np.concatenate(rimages, axis=1))
image = np.concatenate(row_images, axis=0)
plt.imshow(image, cmap = mpl.cm.binary, **options)
plt.axis("off")
#최초 100개의 이미지 그리기
plt.figure(figsize=(9,9))
example_images = X[:100]
plot_digits(example_images, images_per_row=10)
save_fig("more_digits_plot")
plt.show()
Saving figure more_digits_plot
y[0]
5
#학습/ 테스트 데이터 분리
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]
이진분류기 (Binary classifier)¶
문제를 단순화해서 숫자 5만 식별해보자.
y_train_5 = (y_train == 5)
y_test_5 = (y_test == 5)
y_train_5
0 True
1 False
2 False
3 False
4 False
...
59995 False
59996 False
59997 True
59998 False
59999 False
Name: class, Length: 60000, dtype: bool
로지스틱 회귀 모델을 사용해보자.
import warnings
warnings.filterwarnings(action='ignore')
#사이킷런 라이브러리 활용
from sklearn.linear_model import LogisticRegression
log_clf = LogisticRegression(random_state=0).fit(X_train, y_train_5)
log_clf.predict([X[0],X[1],X[2]])
array([ True, False, False])
교차 검증을 사용해서 평가해보자. (fold 갯수 3)
from sklearn.model_selection import cross_val_score
cross_val_score(log_clf, X_train, y_train_5, cv=3, scoring="accuracy")
array([0.97525, 0.9732 , 0.9732 ])
모든 교차 검증 폴드에 대해 정확도가 97% 이상임. 모델이 좋아 보이는가? 이것이 정말 좋은 결과일까? 찾아보자. Never5Classifier 선언 : 무조건 5가 아니다(0) 라고 예측하는 classifier
from sklearn.base import BaseEstimator
class Never5Classifier(BaseEstimator):
def fit(self, X, y=None):
pass
def predict(self, X):
return np.zeros(len(X), dtype=bool)
never_5_clf = Never5Classifier()
cross_val_score(never_5_clf, X_train, y_train_5, cv=3, scoring="accuracy")
array([0.91125, 0.90855, 0.90915])
never_5_clf.predict(X)
array([False, False, False, ..., False, False, False])
이미지의 10%만 숫자 5이기 때문에 무조건 5가 아닌 것으로 예측해도 정확도는 90%가 된다. 즉 97% 정확도가 아주 좋은 예측률이 아닐 수도 있다는 뜻.
목표값(클래스)들이 불균형인 경우에 정확도(accuracy)는 좋은 지표가 아니다.
그렇다면 어떻게 해야할까? -> Precision/Recall matrix를 활용하거나 , F1-score 개선
목표값의 분포가 50:50에 가까우면 accuracy를 써도 된다
오차행렬 (Confusion matrix)¶
#먼저 cross-validation을 활용해서 예측값을 저장하자
from sklearn.model_selection import cross_val_predict
y_train_pred = cross_val_predict(log_clf, X_train, y_train_5, cv=3)
y_train_pred.shape
(60000,)
#오차행렬 생성
from sklearn.metrics import confusion_matrix
confusion_matrix(y_train_5, y_train_pred)
array([[54038, 541],
[ 1026, 4395]])
from sklearn.metrics import precision_score, recall_score
precision_score(y_train_5, y_train_pred)
0.8903970826580226
4395/(4395+541)
0.8903970826580226
recall_score(y_train_5, y_train_pred)
0.8107360265633647
4395/(4395+1026)
0.8107360265633647
confusion_matrix(y_train_5, never_5_clf.predict(X)[:60000])
array([[54579, 0],
[ 5421, 0]])
precision_score(y_train_5, never_5_clf.predict(X)[:60000])
0.0
recall_score(y_train_5, never_5_clf.predict(X)[:60000])
0.0
Error cases 조사하기¶
89%의 정밀도라면 완벽한 모델은 아니다.
어떨 때 에러가 발생하는지 살펴보자
errors = (y_train_pred != y_train_5)
errors
0 False
1 False
2 False
3 False
4 False
...
59995 False
59996 False
59997 False
59998 False
59999 False
Name: class, Length: 60000, dtype: bool
# 에러인 경우 중에서 100개를 뽑아서 보자
plt.figure(figsize=(9,9))
plot_digits(X_train[errors][:100], images_per_row=10)
save_fig("more_digits_plot")
plt.show()
Saving figure more_digits_plot
Precision/Recall Trade-off¶
두개의 모델이
하나는 정밀도가 높지만 재현률이 낮고,
하나는 정밀도가 낮지만 재현률이 높다면,
어떤 모델을 선택해야 할까?
ex 1> 암 진단 모델의 경우
precision에서 문제가 있는 경우?
cancer 있다고 진단했는데 실제로는 cancer가 아닌 경우
recall에서 문제가 있는 경우?
cancer 없다고 진단했는데 실제로는 cancer가 있는 경우
후자가 훨씬 더 심각한 결과를 초래하는 사건 --> recall이 높아야 한다.
ex 2> spam mail 분류 모델의 경우
precision에서 문제가 있는 경우?
spam 아니라고 했는데 실제로는 spam인 경우
recall에서 문제가 있는 경우?
spam 이라고 진단했는데 실제로는 spam 아닌 경우
전자가 훨씬 더 심각한 결과를 초래하는 사건 --> recall이 높아야 한다.
ex 3> 어린이 동영상 등급 분류기 -> precision이 중요하다
좋은 동영상이라고 했는데 실제로는 좋은 동영상이 아닌 경우
$ $
여러 thresholds가 있을 수 있다. 어떤 것을 선택해야 할까?
#에러의 인덱스를 뽑아보자.
for i in range(len(errors)):
if errors[i]:
print(i)
48
132
138
173
244
262
278
502
528
540
554
558
559
610
635
690
748
769
778
832
836
899
916
924
1015
1021
1024
1029
1032
1089
1102
1104
1111
1145
1151
1222
1276
1278
1298
1311
1324
1325
1328
1346
1356
1376
1468
1587
1651
1682
1784
1930
1961
2000
2028
2148
2172
2184
2209
2211
2272
2302
2308
2368
2373
2395
2428
2566
2622
2636
2744
2758
2780
2803
2958
3013
3034
3070
3095
3204
3205
3216
3289
3401
3416
3502
3524
3580
3638
3648
3872
4004
4066
4148
4184
4192
4231
4304
4362
4416
4475
4562
4590
4596
4622
4645
4665
4666
4786
4801
4820
4847
4908
4941
5066
5080
5157
5177
5188
5250
5255
5303
5382
5506
5526
5554
5632
5666
5678
5737
5738
5752
5757
5771
5798
5839
5847
5857
5904
5925
5947
5956
6066
6086
6096
6106
6171
6236
6370
6418
6428
6450
6472
6504
6636
6644
6735
6836
6844
6848
6895
6920
6929
6943
7032
7044
7080
7112
7134
7216
7232
7270
7281
7286
7346
7354
7368
7374
7498
7544
7546
7584
7639
7833
7842
7884
7917
7962
8009
8062
8178
8190
8307
8339
8414
8419
8435
8488
8504
8549
8606
8623
8624
8639
8662
8678
8712
8713
8731
8757
8814
8816
8898
8940
9028
9084
9146
9340
9344
9376
9378
9384
9397
9450
9464
9514
9528
9534
9568
9646
9650
9717
9770
9814
9898
9952
9958
10030
10044
10146
10155
10211
10244
10258
10260
10286
10314
10400
10434
10622
10644
10674
10691
10722
10742
10756
10790
10884
10956
10995
11025
11094
11154
11191
11198
11238
11410
11500
11569
11572
11596
11600
11614
11644
11790
11796
11864
11882
11889
12078
12099
12157
12173
12174
12181
12232
12247
12259
12347
12358
12470
12493
12548
12572
12588
12650
12692
12714
12736
12800
12877
12891
12898
12937
12978
12986
13021
13078
13120
13134
13145
13182
13185
13234
13248
13253
13305
13318
13328
13345
13376
13394
13404
13428
13460
13492
13532
13533
13538
13562
13642
13683
13829
13831
13909
13940
13948
13964
13970
14028
14070
14074
14078
14087
14096
14100
14144
14199
14233
14281
14341
14376
14534
14544
14623
14639
14664
14689
14707
14737
14748
14755
14764
14828
14878
14893
14894
14994
15024
15070
15116
15144
15174
15178
15216
15252
15309
15338
15371
15386
15464
15513
15519
15526
15548
15558
15579
15594
15637
15663
15676
15698
15742
15771
15791
15855
15893
15894
15969
15975
16011
16084
16092
16116
16126
16164
16174
16210
16219
16347
16353
16357
16358
16374
16384
16406
16424
16446
16511
16558
16592
16678
16692
16748
16767
16832
16845
16940
16959
16980
17001
17024
17239
17384
17411
17494
17522
17540
17543
17544
17545
17683
17712
17787
17794
17890
17908
17940
18090
18214
18234
18342
18414
18416
18417
18440
18542
18608
18796
18966
18968
19089
19108
19130
19173
19207
19272
19279
19318
19328
19360
19374
19396
19412
19430
19438
19590
19702
19752
19846
19868
19888
19892
19942
19945
19959
19973
19996
20006
20033
20097
20109
20120
20171
20181
20186
20355
20412
20430
20476
20547
20569
20648
20758
20784
20853
20855
20903
20967
20978
21034
21112
21191
21204
21206
21341
21346
21359
21385
21409
21520
21548
21550
21558
21588
21688
21707
21728
21734
21850
21948
21956
22030
22033
22053
22114
22123
22130
22156
22166
22175
22193
22205
22210
22235
22272
22302
22374
22379
22426
22436
22465
22470
22492
22495
22559
22562
22609
22615
22633
22654
22686
22692
22704
22746
22784
22830
22866
23024
23209
23218
23252
23264
23322
23332
23336
23385
23400
23434
23452
23458
23482
23490
23516
23524
23546
23566
23567
23582
23629
23640
23663
23690
23733
23753
23806
23824
23860
23861
23874
23912
23927
24052
24066
24078
24097
24180
24202
24217
24235
24250
24261
24275
24310
24360
24361
24402
24408
24426
24504
24579
24608
24613
24614
24630
24660
24716
24725
24752
24864
24887
24934
25096
25120
25192
25259
25273
25295
25306
25309
25315
25359
25457
25508
25520
25620
25622
25678
25736
25936
25954
25959
25966
25986
26017
26020
26034
26050
26072
26150
26206
26240
26358
26398
26493
26538
26636
26733
26804
26842
26913
26918
27000
27009
27043
27053
27062
27086
27113
27164
27176
27185
27193
27248
27263
27282
27340
27375
27428
27502
27576
27602
27877
27954
28016
28152
28178
28254
28279
28338
28375
28395
28413
28420
28491
28504
28512
28525
28587
28589
28608
28617
28632
28637
28654
28657
28710
28716
28717
28720
28732
28770
28777
28778
28788
28846
28854
28886
28952
29006
29029
29067
29089
29096
29155
29156
29157
29204
29226
29229
29246
29264
29308
29310
29359
29410
29462
29494
29524
29659
29705
29712
29760
29771
29816
29817
29830
29832
29834
29890
29924
29933
29937
30029
30049
30125
30163
30216
30262
30312
30390
30416
30418
30482
30514
30626
30630
30689
30725
30764
30882
30895
30897
30900
30915
30961
31000
31008
31022
31028
31031
31112
31136
31242
31252
31266
31273
31290
31301
31335
31392
31402
31413
31415
31418
31444
31452
31562
31577
31650
31682
31723
31738
31782
31900
31961
32040
32080
32132
32141
32156
32168
32248
32277
32343
32344
32345
32348
32372
32415
32417
32444
32445
32670
32671
32702
32711
32724
32757
32782
32786
32822
32954
32971
33001
33089
33121
33130
33181
33206
33216
33242
33245
33340
33437
33484
33598
33602
33611
33674
33772
33892
33993
34034
34050
34115
34122
34270
34536
34554
34602
34622
34692
34711
34765
34800
34811
34817
34829
34836
34841
34882
34996
35062
35114
35147
35224
35228
35272
35310
35326
35397
35406
35420
35486
35504
35574
35582
35591
35622
35654
35730
35740
35784
35787
35898
35943
35977
36015
36047
36065
36118
36126
36214
36256
36268
36270
36332
36407
36439
36452
36482
36491
36527
36539
36598
36599
36642
36716
36732
36746
36750
36774
36839
36890
36900
36984
37004
37050
37069
37070
37089
37135
37154
37160
37194
37216
37249
37256
37275
37313
37341
37358
37379
37407
37409
37413
37414
37438
37453
37465
37552
37557
37558
37567
37574
37582
37584
37590
37606
37680
37842
37864
37983
38038
38165
38194
38218
38250
38280
38298
38321
38362
38408
38508
38511
38553
38577
38604
38608
38626
38640
38658
38670
38675
38678
38696
38698
38780
38848
38864
38932
39122
39208
39230
39377
39378
39405
39409
39431
39473
39502
39513
39516
39526
39573
39700
39714
39773
39793
39832
39904
39928
39945
39951
39972
39978
39999
40057
40125
40127
40138
40144
40326
40335
40431
40495
40511
40514
40558
40573
40587
40599
40620
40669
40688
40690
40704
40720
40739
40752
40966
40972
41003
41016
41018
41072
41094
41188
41199
41200
41270
41312
41332
41370
41390
41416
41435
41464
41475
41509
41618
41624
41627
41713
41789
41882
41898
41904
41933
41949
41951
42038
42045
42078
42108
42121
42141
42193
42221
42229
42232
42237
42287
42297
42312
42317
42321
42331
42334
42337
42338
42364
42392
42415
42428
42508
42509
42554
42555
42609
42658
42661
42665
42682
42687
42705
42756
42827
42878
42898
42906
42984
42992
43098
43111
43127
43148
43206
43212
43224
43328
43368
43385
43387
43402
43510
43680
43702
43772
43837
43898
43946
44013
44072
44078
44099
44135
44147
44174
44253
44261
44262
44307
44321
44350
44357
44381
44383
44406
44462
44494
44554
44555
44625
44630
44662
44706
44713
44748
44753
44760
44819
44830
44870
44907
44910
44968
45012
45024
45026
45057
45108
45122
45134
45226
45250
45256
45282
45292
45344
45443
45477
45491
45526
45582
45602
45607
45616
45764
45770
45797
45836
45875
45894
45899
45903
45945
45954
45963
45985
45991
46070
46073
46088
46097
46110
46188
46203
46300
46331
46369
46370
46406
46423
46435
46441
46458
46588
46612
46647
46753
46878
46882
46901
46941
46948
47016
47020
47022
47077
47104
47115
47296
47319
47357
47358
47376
47381
47389
47413
47414
47449
47471
47475
47603
47618
47624
47655
47662
47718
47737
47741
47781
47828
47873
47934
47938
47940
47949
47955
48064
48324
48352
48397
48469
48482
48507
48541
48564
48603
48628
48649
48662
48680
48800
48895
48905
48966
48971
49006
49040
49094
49107
49140
49188
49202
49300
49378
49410
49460
49500
49508
49517
49527
49668
49791
49890
49892
49895
49905
49957
49992
50010
50042
50071
50155
50223
50279
50317
50320
50365
50366
50379
50383
50408
50435
50470
50481
50560
50574
50618
50718
50734
50856
50881
50936
51196
51227
51231
51251
51274
51283
51298
51349
51367
51371
51416
51492
51498
51626
51659
51679
51736
51737
51740
51764
51789
51794
51795
51797
51809
51886
51955
51963
51990
51999
52004
52038
52074
52086
52096
52106
52129
52131
52157
52172
52195
52252
52272
52273
52364
52394
52399
52404
52452
52624
52686
52767
52792
52857
52870
52875
52892
52895
52910
52914
52947
52962
52981
53015
53024
53063
53122
53152
53198
53201
53248
53316
53356
53410
53466
53470
53474
53507
53538
53552
53556
53578
53585
53598
53638
53641
53691
53754
53844
53854
53909
53920
53975
53978
53995
54002
54011
54024
54044
54091
54092
54178
54184
54230
54366
54388
54401
54461
54470
54528
54630
54654
54832
54858
54880
54883
54904
54913
54914
54928
54932
54949
54975
55002
55042
55055
55060
55116
55153
55208
55322
55396
55428
55442
55460
55513
55592
55631
55685
55729
55739
55876
55878
55998
56001
56014
56054
56124
56203
56224
56229
56366
56396
56468
56476
56489
56492
56586
56666
56732
56838
56866
56890
56939
56962
56990
56995
56999
57156
57211
57231
57240
57242
57256
57302
57311
57327
57345
57369
57383
57392
57409
57462
57472
57510
57527
57566
57598
57662
57743
57768
57812
57814
57837
57864
57982
57988
58064
58119
58130
58154
58537
58646
58653
58802
58803
58817
58845
58871
58976
59196
59286
59294
59358
59368
59390
59400
59426
59446
59459
59563
59574
59584
59701
59718
59719
59726
59731
59747
59766
y_train_pred[48], y_train_5[48]
(True, False)
5가 아닌데, 5라고 예측한 경우
some_digit = X_train[48]
#decision_function 은 예측값이 decision surface와 얼마나 떨어져 있는지를 반환한다.
y_scores = log_clf.decision_function([some_digit])
y_scores
array([0.22419047])
some_digit_image = some_digit.reshape(28, 28)
plt.imshow(some_digit_image, cmap=mpl.cm.binary)
plt.axis("off")
save_fig("some_digit_plot")
plt.show()
Saving figure some_digit_plot
threshold = 0
y_some_digit_pred = (y_scores > threshold)
y_some_digit_pred
array([ True])
threshold가 0이면, 0.2는 0보다 크니까 True를 반환한 것
threshold = 0.5
y_some_digit_pred = (y_scores > threshold)
y_some_digit_pred
array([False])
threshold값을 올리므로써, False Positive가 하나 줄었다. 즉, recall이 줄어들고 precision은 올라감
#전체 데이터에 있어서의 decision function 생성
y_scores = cross_val_predict(log_clf, X_train, y_train_5, cv=3,
method="decision_function")
y_scores.shape
(60000,)
from sklearn.metrics import precision_recall_curve
precisions, recalls, thresholds = precision_recall_curve(y_train_5, y_scores)
precisions.shape
(59897,)
thresholds.shape
(59896,)
thresholds 값에 따라서 60000개 보다 조금 작게 shape이 나온다.
def plot_precision_vs_recall(precisions, recalls):
plt.plot(recalls, precisions, "b-", linewidth=2)
plt.xlabel("Recall", fontsize=16)
plt.ylabel("Precision", fontsize=16)
plt.axis([0, 1, 0, 1])
plt.grid(True)
plt.figure(figsize=(8, 6))
plot_precision_vs_recall(precisions, recalls)
save_fig("precision_vs_recall_plot")
plt.show()
Saving figure precision_vs_recall_plot
오른쪽으로 갈수록 threshold 가 낮은 경우
다중 분류 (Multiclass Classification)¶
원래의 문제 (10개의 숫자 중에 하나로 판정)
from sklearn.linear_model import LogisticRegression
softmax_reg = LogisticRegression(multi_class="multinomial",solver="lbfgs", C=10)
softmax_reg.fit(X_train, y_train)
LogisticRegression(C=10, multi_class='multinomial')
softmax_reg.predict(X_train)[:10]
array([5, 0, 4, 1, 9, 2, 1, 3, 1, 4], dtype=uint8)
from sklearn.metrics import accuracy_score
y_pred = softmax_reg.predict(X_test)
accuracy_score(y_test, y_pred)
0.9243
Data Augmentation¶
성능향상 전략 중 하나 y-label은 바꾸지 않고, x 이미지를 조금씩 움직여서 x 데이터를 늘려서
from scipy.ndimage.interpolation import shift
def shift_image(image, dx, dy):
image = image.reshape((28, 28))
shifted_image = shift(image, [dy, dx], cval=0, mode="constant")
return shifted_image.reshape([-1])
image = X_train[1000]
#이미지 아래로 이동
shifted_image_down = shift_image(image, 0, 5)
#이미지 왼쪽으로 이동
shifted_image_left = shift_image(image, -5, 0)
plt.figure(figsize=(12,3))
plt.subplot(131)
plt.title("Original", fontsize=14)
plt.imshow(image.reshape(28, 28), interpolation="nearest", cmap="Greys")
plt.subplot(132)
plt.title("Shifted down", fontsize=14)
plt.imshow(shifted_image_down.reshape(28, 28), interpolation="nearest", cmap="Greys")
plt.subplot(133)
plt.title("Shifted left", fontsize=14)
plt.imshow(shifted_image_left.reshape(28, 28), interpolation="nearest", cmap="Greys")
plt.show()
전후 좌우 말고도 회전 시키는 방법도 있다.
X_train_augmented = [image for image in X_train]
y_train_augmented = [label for label in y_train]
# 네가지 방법으로 데이터 변환
for dx, dy in ((1, 0), (-1, 0), (0, 1), (0, -1)):
for image, label in zip(X_train, y_train):
X_train_augmented.append(shift_image(image, dx, dy))
y_train_augmented.append(label)
X_train_augmented = np.array(X_train_augmented)
y_train_augmented = np.array(y_train_augmented)
X_train_augmented.shape
(300000, 784)
# 학습에 문제가 있을 수 있기 때문에 augmented data 섞어준다
shuffle_idx = np.random.permutation(len(X_train_augmented))
X_train_augmented = X_train_augmented[shuffle_idx]
y_train_augmented = y_train_augmented[shuffle_idx]
X_train_augmented.shape, X_train.shape
((300000, 784), (60000, 784))
softmax_reg_augmented = LogisticRegression(multi_class="multinomial",solver="lbfgs", C=10)
softmax_reg_augmented.fit(X_train_augmented, y_train_augmented)
LogisticRegression(C=10, multi_class='multinomial')
y_pred = softmax_reg_augmented.predict(X_test)
accuracy_score(y_test, y_pred)
0.9279
augmentation해서 정확도가 조금 올라갔다.
Titanic 데이터셋¶
import numpy as np
import pandas as pd
train_data = pd.read_csv("titanic.csv")
train_data.head()
PassengerId | Survived | Pclass | Name | Sex | Age | SibSp | Parch | Ticket | Fare | Cabin | Embarked | |
---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 1 | 0 | 3 | Braund, Mr. Owen Harris | male | 22.0 | 1 | 0 | A/5 21171 | 7.2500 | NaN | S |
1 | 2 | 1 | 1 | Cumings, Mrs. John Bradley (Florence Briggs Th... | female | 38.0 | 1 | 0 | PC 17599 | 71.2833 | C85 | C |
2 | 3 | 1 | 3 | Heikkinen, Miss. Laina | female | 26.0 | 0 | 0 | STON/O2. 3101282 | 7.9250 | NaN | S |
3 | 4 | 1 | 1 | Futrelle, Mrs. Jacques Heath (Lily May Peel) | female | 35.0 | 1 | 0 | 113803 | 53.1000 | C123 | S |
4 | 5 | 0 | 3 | Allen, Mr. William Henry | male | 35.0 | 0 | 0 | 373450 | 8.0500 | NaN | S |
속성들
- Survived: that's the target, 0 means the passenger did not survive, while 1 means he/she survived.
- Pclass: passenger class.
- Name, Sex, Age: self-explanatory
- SibSp: how many siblings & spouses of the passenger aboard the Titanic.
- Parch: how many children & parents of the passenger aboard the Titanic.
- Ticket: ticket id
- Fare: price paid (in pounds)
- Cabin: passenger's cabin number
- Embarked: where the passenger embarked the Titanic
식별자를 feature로 넣어서 학습시키게 되면 새로운 데이터가 들어왔을 때 예측값이 굉장히 안좋아질 수 있는 가능성이 있다. 식별자적 속성이 있는 변수는 feature로 넣지 않도록
train_data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 891 entries, 0 to 890
Data columns (total 12 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 PassengerId 891 non-null int64
1 Survived 891 non-null int64
2 Pclass 891 non-null int64
3 Name 891 non-null object
4 Sex 891 non-null object
5 Age 714 non-null float64
6 SibSp 891 non-null int64
7 Parch 891 non-null int64
8 Ticket 891 non-null object
9 Fare 891 non-null float64
10 Cabin 204 non-null object
11 Embarked 889 non-null object
dtypes: float64(2), int64(5), object(5)
memory usage: 83.7+ KB
Age, Cabin, Embarked 속성들이 missing value를 가지고 있다.
Cabin, Name, Ticket 속성들은 무시한다.
train_data.describe()
PassengerId | Survived | Pclass | Age | SibSp | Parch | Fare | |
---|---|---|---|---|---|---|---|
count | 891.000000 | 891.000000 | 891.000000 | 714.000000 | 891.000000 | 891.000000 | 891.000000 |
mean | 446.000000 | 0.383838 | 2.308642 | 29.699118 | 0.523008 | 0.381594 | 32.204208 |
std | 257.353842 | 0.486592 | 0.836071 | 14.526497 | 1.102743 | 0.806057 | 49.693429 |
min | 1.000000 | 0.000000 | 1.000000 | 0.420000 | 0.000000 | 0.000000 | 0.000000 |
25% | 223.500000 | 0.000000 | 2.000000 | 20.125000 | 0.000000 | 0.000000 | 7.910400 |
50% | 446.000000 | 0.000000 | 3.000000 | 28.000000 | 0.000000 | 0.000000 | 14.454200 |
75% | 668.500000 | 1.000000 | 3.000000 | 38.000000 | 1.000000 | 0.000000 | 31.000000 |
max | 891.000000 | 1.000000 | 3.000000 | 80.000000 | 8.000000 | 6.000000 | 512.329200 |
오직 40% 미만이 생존했음을 알 수 있다.
train_data["Survived"].value_counts()
0 549
1 342
Name: Survived, dtype: int64
Categorical 속성들을 조사해보자.
train_data["Pclass"].value_counts()
3 491
1 216
2 184
Name: Pclass, dtype: int64
train_data["Sex"].value_counts()
male 577
female 314
Name: Sex, dtype: int64
train_data["Embarked"].value_counts()
S 644
C 168
Q 77
Name: Embarked, dtype: int64
from sklearn.base import BaseEstimator, TransformerMixin
#원하는 속성만 선택해서 쓰기 위해서
class DataFrameSelector(BaseEstimator, TransformerMixin):
def __init__(self, attribute_names):
self.attribute_names = attribute_names
def fit(self, X, y=None):
return self
def transform(self, X):
return X[self.attribute_names]
Numerical 속성을 처리하는 pipeline을 만든다.
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
num_pipeline = Pipeline([
("select_numeric", DataFrameSelector(["Age", "SibSp", "Parch", "Fare"])),
#결측치를 중간값으로 넣음
("imputer", SimpleImputer(strategy="median")),
])
num_pipeline.fit_transform(train_data)
array([[22. , 1. , 0. , 7.25 ],
[38. , 1. , 0. , 71.2833],
[26. , 0. , 0. , 7.925 ],
...,
[28. , 1. , 2. , 23.45 ],
[26. , 0. , 0. , 30. ],
[32. , 0. , 0. , 7.75 ]])
# 결측치를 최빈값으로 넣음
class MostFrequentImputer(BaseEstimator, TransformerMixin):
def fit(self, X, y=None):
self.most_frequent_ = pd.Series([X[c].value_counts().index[0] for c in X],
index=X.columns)
return self
def transform(self, X, y=None):
return X.fillna(self.most_frequent_)
from sklearn.preprocessing import OneHotEncoder
cat_pipeline = Pipeline([
("select_cat", DataFrameSelector(["Pclass", "Sex", "Embarked"])),
("imputer", MostFrequentImputer()),
("cat_encoder", OneHotEncoder(sparse=False)),
])
cat_pipeline.fit_transform(train_data)
array([[0., 0., 1., ..., 0., 0., 1.],
[1., 0., 0., ..., 1., 0., 0.],
[0., 0., 1., ..., 0., 0., 1.],
...,
[0., 0., 1., ..., 0., 0., 1.],
[1., 0., 0., ..., 1., 0., 0.],
[0., 0., 1., ..., 0., 1., 0.]])
cat_pipeline.fit_transform(train_data)[0]
array([0., 0., 1., 0., 1., 0., 0., 1.])
Categorical, numerical 속성들을 통합한다.
from sklearn.pipeline import FeatureUnion
preprocess_pipeline = FeatureUnion(transformer_list=[
("num_pipeline", num_pipeline),
("cat_pipeline", cat_pipeline),
])
X_train = preprocess_pipeline.fit_transform(train_data)
X_train
array([[22., 1., 0., ..., 0., 0., 1.],
[38., 1., 0., ..., 1., 0., 0.],
[26., 0., 0., ..., 0., 0., 1.],
...,
[28., 1., 2., ..., 0., 0., 1.],
[26., 0., 0., ..., 1., 0., 0.],
[32., 0., 0., ..., 0., 1., 0.]])
X_train.shape
(891, 12)
목표값 벡터
y_train = train_data["Survived"]
y_train
0 0
1 1
2 1
3 1
4 0
..
886 0
887 1
888 0
889 1
890 0
Name: Survived, Length: 891, dtype: int64
log_clf = LogisticRegression(random_state=0).fit(X_train, y_train)
#모델이 predict한 값(score)와 목표값, input속성들을 concatenate함
a = np.c_[log_clf.decision_function(X_train), y_train, X_train]
df = pd.DataFrame(data=a, columns=["Score", "Survived", "Age", "SibSp", "Parch", "Fare", "Pclass_1", "Pclass_2", "Pclass_3", "Female", "Male", "Embarked_C", "Embarked_Q", "Embarked_S"])
df
Score | Survived | Age | SibSp | Parch | Fare | Pclass_1 | Pclass_2 | Pclass_3 | Female | Male | Embarked_C | Embarked_Q | Embarked_S | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | -2.333812 | 0.0 | 22.0 | 1.0 | 0.0 | 7.2500 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 1.0 |
1 | 2.346548 | 1.0 | 38.0 | 1.0 | 0.0 | 71.2833 | 1.0 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 |
2 | 0.483770 | 1.0 | 26.0 | 0.0 | 0.0 | 7.9250 | 0.0 | 0.0 | 1.0 | 1.0 | 0.0 | 0.0 | 0.0 | 1.0 |
3 | 1.997652 | 1.0 | 35.0 | 1.0 | 0.0 | 53.1000 | 1.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 1.0 |
4 | -2.490752 | 0.0 | 35.0 | 0.0 | 0.0 | 8.0500 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 1.0 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
886 | -1.020055 | 0.0 | 27.0 | 0.0 | 0.0 | 13.0000 | 0.0 | 1.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 1.0 |
887 | 2.829704 | 1.0 | 19.0 | 0.0 | 0.0 | 30.0000 | 1.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 1.0 |
888 | -0.041455 | 0.0 | 28.0 | 1.0 | 2.0 | 23.4500 | 0.0 | 0.0 | 1.0 | 1.0 | 0.0 | 0.0 | 0.0 | 1.0 |
889 | 0.333715 | 1.0 | 26.0 | 0.0 | 0.0 | 30.0000 | 1.0 | 0.0 | 0.0 | 0.0 | 1.0 | 1.0 | 0.0 | 0.0 |
890 | -2.051100 | 0.0 | 32.0 | 0.0 | 0.0 | 7.7500 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 |
891 rows × 14 columns
df.sort_values(by=['Score'], ascending=False)[:20]
Score | Survived | Age | SibSp | Parch | Fare | Pclass_1 | Pclass_2 | Pclass_3 | Female | Male | Embarked_C | Embarked_Q | Embarked_S | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
258 | 4.041730 | 1.0 | 35.0 | 0.0 | 0.0 | 512.3292 | 1.0 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 |
700 | 3.526282 | 1.0 | 18.0 | 1.0 | 0.0 | 227.5250 | 1.0 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 |
689 | 3.407084 | 1.0 | 15.0 | 0.0 | 1.0 | 211.3375 | 1.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 1.0 |
329 | 3.334674 | 1.0 | 16.0 | 0.0 | 1.0 | 57.9792 | 1.0 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 |
297 | 3.303046 | 0.0 | 2.0 | 1.0 | 2.0 | 151.5500 | 1.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 1.0 |
307 | 3.220966 | 1.0 | 17.0 | 1.0 | 0.0 | 108.9000 | 1.0 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 |
310 | 3.206418 | 1.0 | 24.0 | 0.0 | 0.0 | 83.1583 | 1.0 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 |
369 | 3.166488 | 1.0 | 24.0 | 0.0 | 0.0 | 69.3000 | 1.0 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 |
641 | 3.166488 | 1.0 | 24.0 | 0.0 | 0.0 | 69.3000 | 1.0 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 |
306 | 3.140391 | 1.0 | 28.0 | 0.0 | 0.0 | 110.8833 | 1.0 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 |
311 | 3.129693 | 1.0 | 18.0 | 2.0 | 2.0 | 262.3750 | 1.0 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 |
716 | 3.111692 | 1.0 | 38.0 | 0.0 | 0.0 | 227.5250 | 1.0 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 |
710 | 3.109451 | 1.0 | 24.0 | 0.0 | 0.0 | 49.5042 | 1.0 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 |
504 | 3.101930 | 1.0 | 16.0 | 0.0 | 0.0 | 86.5000 | 1.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 1.0 |
291 | 3.096663 | 1.0 | 19.0 | 1.0 | 0.0 | 91.0792 | 1.0 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 |
708 | 3.070492 | 1.0 | 22.0 | 0.0 | 0.0 | 151.5500 | 1.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 1.0 |
537 | 3.054589 | 1.0 | 30.0 | 0.0 | 0.0 | 106.4250 | 1.0 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 |
256 | 3.049102 | 1.0 | 28.0 | 0.0 | 0.0 | 79.2000 | 1.0 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 |
742 | 3.020260 | 1.0 | 21.0 | 2.0 | 2.0 | 262.3750 | 1.0 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 |
393 | 3.014705 | 1.0 | 23.0 | 1.0 | 0.0 | 113.2750 | 1.0 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 |
스코어가 높을수록 대체적으로 실제로 살아남았다
df.sort_values(by=['Score'])[:20]
Score | Survived | Age | SibSp | Parch | Fare | Pclass_1 | Pclass_2 | Pclass_3 | Female | Male | Embarked_C | Embarked_Q | Embarked_S | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
159 | -4.759969 | 0.0 | 28.0 | 8.0 | 2.0 | 69.5500 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 1.0 |
324 | -4.759969 | 0.0 | 28.0 | 8.0 | 2.0 | 69.5500 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 1.0 |
201 | -4.759969 | 0.0 | 28.0 | 8.0 | 2.0 | 69.5500 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 1.0 |
846 | -4.759969 | 0.0 | 28.0 | 8.0 | 2.0 | 69.5500 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 1.0 |
851 | -3.914177 | 0.0 | 74.0 | 0.0 | 0.0 | 7.7750 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 1.0 |
116 | -3.455494 | 0.0 | 70.5 | 0.0 | 0.0 | 7.7500 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 |
326 | -3.444396 | 0.0 | 61.0 | 0.0 | 0.0 | 6.2375 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 1.0 |
683 | -3.369645 | 0.0 | 14.0 | 5.0 | 2.0 | 46.9000 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 1.0 |
94 | -3.368523 | 0.0 | 59.0 | 0.0 | 0.0 | 7.2500 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 1.0 |
13 | -3.339799 | 0.0 | 39.0 | 1.0 | 5.0 | 31.2750 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 1.0 |
860 | -3.322094 | 0.0 | 41.0 | 2.0 | 0.0 | 14.1083 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 1.0 |
360 | -3.294984 | 0.0 | 40.0 | 1.0 | 4.0 | 27.9000 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 1.0 |
59 | -3.260211 | 0.0 | 11.0 | 5.0 | 2.0 | 46.9000 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 1.0 |
280 | -3.254866 | 0.0 | 65.0 | 0.0 | 0.0 | 7.7500 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 |
152 | -3.238546 | 0.0 | 55.5 | 0.0 | 0.0 | 8.0500 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 1.0 |
176 | -3.221140 | 0.0 | 28.0 | 3.0 | 1.0 | 25.4667 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 1.0 |
104 | -3.193999 | 0.0 | 37.0 | 2.0 | 0.0 | 7.9250 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 1.0 |
480 | -3.187256 | 0.0 | 9.0 | 5.0 | 2.0 | 46.9000 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 1.0 |
631 | -3.077265 | 0.0 | 51.0 | 0.0 | 0.0 | 7.0542 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 1.0 |
406 | -3.075261 | 0.0 | 51.0 | 0.0 | 0.0 | 7.7500 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 1.0 |
스코어가 낮을수록 살아남지 못했다.
'Machine Learning' 카테고리의 다른 글
[AI class w8d4] 신경망의 기초 - 기계학습과 수학 (0) | 2021.06.18 |
---|---|
[AI class w8d3] 신경망의 기초 - 인공지능과 기계학습 소개 (0) | 2021.06.18 |
[AI class w7d4] Linear Model for Classification 선형분류 TIL (0) | 2021.06.14 |
[AI class w7d3] Linear Model for Regression 선형회귀 TIL (0) | 2021.06.09 |
[AI class w6d5] Week6 과제 ML Basics 실습 (0) | 2021.06.07 |