These are my notes on the code I followed along with and the key points from the course "E-commerce Data Analysis with Python".
- Source: Fast Campus
Chapter.05 Purchase Factor Analysis (Decision Tree)
Purpose of the Analysis
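Use a decision tree model to predict whether an item listed on an online auction sells, and check how much each variable influences the outcome.
This is a binary classification problem.
The data comes from an online auction site.
- The items are Galaxy phones. Which characteristics make a phone sell well? We will predict how likely a new listing is to sell and, going further, look at which variables affect sales.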
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
data = pd.read_csv('./data/galaxy.csv')
data
| | BuyItNow | startprice | carrier | color | productline | noDescription | charCountDescription | upperCaseDescription | sold |
---|---|---|---|---|---|---|---|---|---|
0 | 0 | 199.99 | None | White | Galaxy_S9 | contains description | 0 | 0 | 1 |
1 | 0 | 235.00 | None | NaN | Galaxy_Note9 | contains description | 0 | 0 | 0 |
2 | 0 | 199.99 | NaN | NaN | Unknown | no description | 100 | 2 | 0 |
3 | 1 | 175.00 | AT&T | Space Gray | Galaxy_Note9 | contains description | 0 | 0 | 1 |
4 | 1 | 100.00 | None | Space Gray | Galaxy_S8 | contains description | 0 | 0 | 1 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
1480 | 0 | 89.50 | AT&T | NaN | Galaxy_S7 | no description | 96 | 2 | 0 |
1481 | 0 | 239.95 | None | Midnight Black | Galaxy_S9 | no description | 97 | 5 | 1 |
1482 | 0 | 329.99 | None | Space Gray | Galaxy_Note10 | no description | 93 | 1 | 0 |
1483 | 0 | 89.00 | None | Midnight Black | Galaxy_S7 | no description | 92 | 2 | 1 |
1484 | 0 | 119.99 | AT&T | Midnight Black | Galaxy_S7 | no description | 96 | 5 | 0 |
1485 rows × 9 columns
As usual, load the data first and look at its columns and their characteristics, keeping in mind that it comes from an auction site.
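- BuyItNow : whether the item can be bought immediately (Buy It Now)
- startprice : starting price of the auction
- carrier : name of the US carrier
- color : device color
- productline : model name
- noDescription : whether the seller wrote a description
- charCountDescription : length of the description (character count)
- upperCaseDescription : how many sentences the description has
- sold : whether the item sold (the dependent variable we want to predict)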
data.head()
| | BuyItNow | startprice | carrier | color | productline | noDescription | charCountDescription | upperCaseDescription | sold |
---|---|---|---|---|---|---|---|---|---|
0 | 0 | 199.99 | None | White | Galaxy_S9 | contains description | 0 | 0 | 1 |
1 | 0 | 235.00 | None | NaN | Galaxy_Note9 | contains description | 0 | 0 | 0 |
2 | 0 | 199.99 | NaN | NaN | Unknown | no description | 100 | 2 | 0 |
3 | 1 | 175.00 | AT&T | Space Gray | Galaxy_Note9 | contains description | 0 | 0 | 1 |
4 | 1 | 100.00 | None | Space Gray | Galaxy_S8 | contains description | 0 | 0 | 1 |
data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1485 entries, 0 to 1484
Data columns (total 9 columns):

| # | Column | Non-Null Count | Dtype |
---|---|---|---|
0 | BuyItNow | 1485 non-null | int64 |
1 | startprice | 1485 non-null | float64 |
2 | carrier | 1179 non-null | object |
3 | color | 892 non-null | object |
4 | productline | 1485 non-null | object |
5 | noDescription | 1485 non-null | object |
6 | charCountDescription | 1485 non-null | int64 |
7 | upperCaseDescription | 1485 non-null | int64 |
8 | sold | 1485 non-null | int64 |

dtypes: float64(1), int64(4), object(4)
memory usage: 104.5+ KB
data.describe()
| | BuyItNow | startprice | charCountDescription | upperCaseDescription | sold |
---|---|---|---|---|---|
count | 1485.000000 | 1485.000000 | 1485.000000 | 1485.000000 | 1485.000000 |
mean | 0.449158 | 216.844162 | 31.184512 | 2.863300 | 0.461953 |
std | 0.497576 | 172.893308 | 41.744518 | 9.418585 | 0.498718 |
min | 0.000000 | 0.010000 | 0.000000 | 0.000000 | 0.000000 |
25% | 0.000000 | 80.000000 | 0.000000 | 0.000000 | 0.000000 |
50% | 0.000000 | 198.000000 | 0.000000 | 0.000000 | 0.000000 |
75% | 1.000000 | 310.000000 | 79.000000 | 2.000000 | 1.000000 |
max | 1.000000 | 999.000000 | 111.000000 | 81.000000 | 1.000000 |
Let's look at the data a bit more closely with plots.
sns.histplot(data['startprice'], kde=True, stat='density')  # distplot is deprecated in newer seaborn
[Figure: distribution of startprice (x: startprice, y: Density)]
Many auctions start in the low price range.
sns.histplot(data['charCountDescription'], kde=True, stat='density')  # distplot is deprecated in newer seaborn
[Figure: distribution of charCountDescription (x: charCountDescription, y: Density)]
The most frequent value is 0, followed by values around 100.
Next, let's compare starting prices across models with a boxplot.
plt.figure(figsize= (20,10))
sns.boxplot(x = 'productline', y = 'startprice', data = data)
[Figure: boxplot of startprice by productline]
In general, newer models tend to have higher starting prices.
Let's check for missing values.
data.isna()
| | BuyItNow | startprice | carrier | color | productline | noDescription | charCountDescription | upperCaseDescription | sold |
---|---|---|---|---|---|---|---|---|---|
0 | False | False | False | False | False | False | False | False | False |
1 | False | False | False | True | False | False | False | False | False |
2 | False | False | True | True | False | False | False | False | False |
3 | False | False | False | False | False | False | False | False | False |
4 | False | False | False | False | False | False | False | False | False |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
1480 | False | False | False | True | False | False | False | False | False |
1481 | False | False | False | False | False | False | False | False | False |
1482 | False | False | False | False | False | False | False | False | False |
1483 | False | False | False | False | False | False | False | False | False |
1484 | False | False | False | False | False | False | False | False | False |
1485 rows × 9 columns
data.isna().sum()
BuyItNow 0
startprice 0
carrier 306
color 593
productline 0
noDescription 0
charCountDescription 0
upperCaseDescription 0
sold 0
dtype: int64
data.isna().sum() / len(data)
BuyItNow 0.000000
startprice 0.000000
carrier 0.206061
color 0.399327
productline 0.000000
noDescription 0.000000
charCountDescription 0.000000
upperCaseDescription 0.000000
sold 0.000000
dtype: float64
About 20.6% of carrier and 39.9% of color are missing.
Let's fill in the missing values. carrier and color are both text columns.
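`isna().sum() / len(data)` can also be written as `isna().mean()`, because the mean of a boolean column is the share of True values. A minimal sketch on a small made-up frame (the values are illustrative, not taken from galaxy.csv):

```python
import numpy as np
import pandas as pd

# Toy frame standing in for the auction data (made-up values)
df = pd.DataFrame({
    "carrier": ["None", np.nan, "AT&T", "Verizon", np.nan],
    "color": [np.nan, np.nan, "White", np.nan, "Gold"],
    "sold": [1, 0, 1, 0, 1],
})

# Both expressions give the fraction of missing values per column
ratio_sum = df.isna().sum() / len(df)
ratio_mean = df.isna().mean()
print(ratio_mean)
```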
data.head()
| | BuyItNow | startprice | carrier | color | productline | noDescription | charCountDescription | upperCaseDescription | sold |
---|---|---|---|---|---|---|---|---|---|
0 | 0 | 199.99 | None | White | Galaxy_S9 | contains description | 0 | 0 | 1 |
1 | 0 | 235.00 | None | NaN | Galaxy_Note9 | contains description | 0 | 0 | 0 |
2 | 0 | 199.99 | NaN | NaN | Unknown | no description | 100 | 2 | 0 |
3 | 1 | 175.00 | AT&T | Space Gray | Galaxy_Note9 | contains description | 0 | 0 | 1 |
4 | 1 | 100.00 | None | Space Gray | Galaxy_S8 | contains description | 0 | 0 | 1 |
- None is not a missing value: the data actually contains the text "None". NaN is the missing value.
- Let's fill in the NaN values.
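Whether "None" survives as text depends on how the CSV is read. To my knowledge, pandas 2.0 added the string "None" to the default missing-value markers of `read_csv`, so on newer versions those rows could silently become NaN and merge with the genuinely missing ones. A minimal sketch of keeping "None" as a category, assuming missing cells in the file are empty fields (the toy CSV below is made up):

```python
import io
import pandas as pd

# Made-up CSV: "None" is a real category, empty fields are missing values
csv_text = "carrier,color\nNone,White\n,Gold\nAT&T,\n"

# keep_default_na=False turns off the built-in NA strings (which may include "None");
# na_values=[""] still treats empty fields as NaN
df = pd.read_csv(io.StringIO(csv_text), keep_default_na=False, na_values=[""])
print(df)
```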
data.fillna('Unkown')
| | BuyItNow | startprice | carrier | color | productline | noDescription | charCountDescription | upperCaseDescription | sold |
---|---|---|---|---|---|---|---|---|---|
0 | 0 | 199.99 | None | White | Galaxy_S9 | contains description | 0 | 0 | 1 |
1 | 0 | 235.00 | None | Unkown | Galaxy_Note9 | contains description | 0 | 0 | 0 |
2 | 0 | 199.99 | Unkown | Unkown | Unknown | no description | 100 | 2 | 0 |
3 | 1 | 175.00 | AT&T | Space Gray | Galaxy_Note9 | contains description | 0 | 0 | 1 |
4 | 1 | 100.00 | None | Space Gray | Galaxy_S8 | contains description | 0 | 0 | 1 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
1480 | 0 | 89.50 | AT&T | Unkown | Galaxy_S7 | no description | 96 | 2 | 0 |
1481 | 0 | 239.95 | None | Midnight Black | Galaxy_S9 | no description | 97 | 5 | 1 |
1482 | 0 | 329.99 | None | Space Gray | Galaxy_Note10 | no description | 93 | 1 | 0 |
1483 | 0 | 89.00 | None | Midnight Black | Galaxy_S7 | no description | 92 | 2 | 1 |
1484 | 0 | 119.99 | AT&T | Midnight Black | Galaxy_S7 | no description | 96 | 5 | 0 |
1485 rows × 9 columns
data = data.fillna('Unkown')  # assign the result back to data
data
| | BuyItNow | startprice | carrier | color | productline | noDescription | charCountDescription | upperCaseDescription | sold |
---|---|---|---|---|---|---|---|---|---|
0 | 0 | 199.99 | None | White | Galaxy_S9 | contains description | 0 | 0 | 1 |
1 | 0 | 235.00 | None | Unkown | Galaxy_Note9 | contains description | 0 | 0 | 0 |
2 | 0 | 199.99 | Unkown | Unkown | Unknown | no description | 100 | 2 | 0 |
3 | 1 | 175.00 | AT&T | Space Gray | Galaxy_Note9 | contains description | 0 | 0 | 1 |
4 | 1 | 100.00 | None | Space Gray | Galaxy_S8 | contains description | 0 | 0 | 1 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
1480 | 0 | 89.50 | AT&T | Unkown | Galaxy_S7 | no description | 96 | 2 | 0 |
1481 | 0 | 239.95 | None | Midnight Black | Galaxy_S9 | no description | 97 | 5 | 1 |
1482 | 0 | 329.99 | None | Space Gray | Galaxy_Note10 | no description | 93 | 1 | 0 |
1483 | 0 | 89.00 | None | Midnight Black | Galaxy_S7 | no description | 92 | 2 | 1 |
1484 | 0 | 119.99 | AT&T | Midnight Black | Galaxy_S7 | no description | 96 | 5 | 0 |
1485 rows × 9 columns
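The NaN values are now 'Unkown' (the fill string is spelled this way in the code, so the same spelling appears in the dummy column names later). Let's also check the unique values.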
data['carrier'].value_counts()
None 863
Unkown 306
AT&T 177
Verizon 87
Sprint/T-Mobile 52
Name: carrier, dtype: int64
Handling categorical data
data[['carrier','color','productline','noDescription']]
| | carrier | color | productline | noDescription |
---|---|---|---|---|
0 | None | White | Galaxy_S9 | contains description |
1 | None | Unkown | Galaxy_Note9 | contains description |
2 | Unkown | Unkown | Unknown | no description |
3 | AT&T | Space Gray | Galaxy_Note9 | contains description |
4 | None | Space Gray | Galaxy_S8 | contains description |
... | ... | ... | ... | ... |
1480 | AT&T | Unkown | Galaxy_S7 | no description |
1481 | None | Midnight Black | Galaxy_S9 | no description |
1482 | None | Space Gray | Galaxy_Note10 | no description |
1483 | None | Midnight Black | Galaxy_S7 | no description |
1484 | AT&T | Midnight Black | Galaxy_S7 | no description |
1485 rows × 4 columns
data[['carrier','color','productline','noDescription']].nunique()
carrier 5
color 8
productline 8
noDescription 2
dtype: int64
data['carrier'].value_counts()
None 863
Unkown 306
AT&T 177
Verizon 87
Sprint/T-Mobile 52
Name: carrier, dtype: int64
data['color'].value_counts()
Unkown 593
White 328
Midnight Black 274
Space Gray 180
Gold 52
Black 38
Aura Black 19
Prism Black 1
Name: color, dtype: int64
data['productline'].value_counts()
Galaxy_Note10 351
Galaxy_S8 277
Galaxy_S7 227
Unknown 204
Galaxy_S9 158
Galaxy_Note8 153
Galaxy_Note9 107
Galaxy_S10 8
Name: productline, dtype: int64
data['noDescription'].value_counts()
contains description 856
no description 629
Name: noDescription, dtype: int64
color has four kinds of Black (Black, Midnight Black, Aura Black, Prism Black), so let's merge them into a single Black.
def black(x):
    # Merge the Black variants into a single 'Black'; leave other colors unchanged
    if x in ['Midnight Black', 'Aura Black', 'Prism Black']:
        return 'Black'
    else:
        return x
data['color'].apply(lambda x: black(x))
0 White
1 Unkown
2 Unkown
3 Space Gray
4 Space Gray
...
1480 Unkown
1481 Black
1482 Space Gray
1483 Black
1484 Black
Name: color, Length: 1485, dtype: object
data['color'].apply(lambda x: black(x)).value_counts()
Unkown 593
Black 332
White 328
Space Gray 180
Gold 52
Name: color, dtype: int64
All the Black variants now fall into a single Black category. Assign the result back to the color column.
data['color'] = data['color'].apply(lambda x: black(x))
data['color'].value_counts()
Unkown 593
Black 332
White 328
Space Gray 180
Gold 52
Name: color, dtype: int64
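The same merge can also be written without a helper function, using `Series.replace` with a mapping dict; values not in the dict pass through unchanged. A sketch on a toy Series (made-up values):

```python
import pandas as pd

colors = pd.Series(["White", "Midnight Black", "Aura Black", "Prism Black", "Black", "Gold"])

# Map every Black variant onto a single "Black" label; other values are left as they are
black_variants = {"Midnight Black": "Black", "Aura Black": "Black", "Prism Black": "Black"}
merged = colors.replace(black_variants)
print(merged.value_counts())
```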
data.head(3)
| | BuyItNow | startprice | carrier | color | productline | noDescription | charCountDescription | upperCaseDescription | sold |
---|---|---|---|---|---|---|---|---|---|
0 | 0 | 199.99 | None | White | Galaxy_S9 | contains description | 0 | 0 | 1 |
1 | 0 | 235.00 | None | Unkown | Galaxy_Note9 | contains description | 0 | 0 | 0 |
2 | 0 | 199.99 | Unkown | Unkown | Unknown | no description | 100 | 2 | 0 |
pd.get_dummies(data, columns = ['carrier', 'color', 'productline', 'noDescription'], drop_first = True)
| | BuyItNow | startprice | charCountDescription | upperCaseDescription | sold | carrier_None | carrier_Sprint/T-Mobile | carrier_Unkown | carrier_Verizon | color_Gold | ... | color_Unkown | color_White | productline_Galaxy_Note8 | productline_Galaxy_Note9 | productline_Galaxy_S10 | productline_Galaxy_S7 | productline_Galaxy_S8 | productline_Galaxy_S9 | productline_Unknown | noDescription_no description |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0 | 199.99 | 0 | 0 | 1 | 1 | 0 | 0 | 0 | 0 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 |
1 | 0 | 235.00 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | ... | 1 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
2 | 0 | 199.99 | 100 | 2 | 0 | 0 | 0 | 1 | 0 | 0 | ... | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 1 |
3 | 1 | 175.00 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
4 | 1 | 100.00 | 0 | 0 | 1 | 1 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
1480 | 0 | 89.50 | 96 | 2 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 1 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 1 |
1481 | 0 | 239.95 | 97 | 5 | 1 | 1 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 1 |
1482 | 0 | 329.99 | 93 | 1 | 0 | 1 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 |
1483 | 0 | 89.00 | 92 | 2 | 1 | 1 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 1 |
1484 | 0 | 119.99 | 96 | 5 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 1 |
1485 rows × 21 columns
data = pd.get_dummies(data, columns = ['carrier', 'color', 'productline', 'noDescription'], drop_first = True)
data
| | BuyItNow | startprice | charCountDescription | upperCaseDescription | sold | carrier_None | carrier_Sprint/T-Mobile | carrier_Unkown | carrier_Verizon | color_Gold | ... | color_Unkown | color_White | productline_Galaxy_Note8 | productline_Galaxy_Note9 | productline_Galaxy_S10 | productline_Galaxy_S7 | productline_Galaxy_S8 | productline_Galaxy_S9 | productline_Unknown | noDescription_no description |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0 | 199.99 | 0 | 0 | 1 | 1 | 0 | 0 | 0 | 0 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 |
1 | 0 | 235.00 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | ... | 1 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
2 | 0 | 199.99 | 100 | 2 | 0 | 0 | 0 | 1 | 0 | 0 | ... | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 1 |
3 | 1 | 175.00 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
4 | 1 | 100.00 | 0 | 0 | 1 | 1 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
1480 | 0 | 89.50 | 96 | 2 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 1 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 1 |
1481 | 0 | 239.95 | 97 | 5 | 1 | 1 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 1 |
1482 | 0 | 329.99 | 93 | 1 | 0 | 1 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 |
1483 | 0 | 89.00 | 92 | 2 | 1 | 1 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 1 |
1484 | 0 | 119.99 | 96 | 5 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 1 |
1485 rows × 21 columns
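`drop_first=True` drops the alphabetically first category of each encoded column, since that category is implied when all the other dummies are 0. That is why carrier_AT&T, color_Black, productline_Galaxy_Note10 and noDescription_contains description do not appear among the 21 columns above. A small sketch on a made-up frame:

```python
import pandas as pd

toy = pd.DataFrame({
    "carrier": ["AT&T", "None", "Verizon", "None"],
    "startprice": [199.99, 235.0, 175.0, 100.0],
})

# Without drop_first: one 0/1 column per category
full = pd.get_dummies(toy, columns=["carrier"])
# With drop_first: the first category in sorted order (AT&T) becomes the baseline
reduced = pd.get_dummies(toy, columns=["carrier"], drop_first=True)
print(full.columns.tolist())
print(reduced.columns.tolist())
```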
Preprocessing is done, so let's split the data and train a model.
from sklearn.model_selection import train_test_split
X = data.drop('sold', axis = 1)
y = data['sold']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2 , random_state= 100)
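With `test_size=0.2`, scikit-learn puts ceil(0.2 × n) rows in the test set. For 1,485 rows that is 297 test rows and 1,188 training rows, matching `Length: 297` for y_test and `samples = 1188` at the root of the tree later on. A quick check with placeholder arrays of the same length (not the real features):

```python
import numpy as np
from sklearn.model_selection import train_test_split

n_rows = 1485  # number of rows in the galaxy data
X_dummy = np.arange(n_rows).reshape(-1, 1)  # placeholder features
y_dummy = np.arange(n_rows) % 2             # placeholder labels

X_tr, X_te, y_tr, y_te = train_test_split(X_dummy, y_dummy, test_size=0.2, random_state=100)
print(len(X_tr), len(X_te))
```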
Import the module needed for modeling.
from sklearn.tree import DecisionTreeClassifier
model = DecisionTreeClassifier(max_depth = 10)
model.fit(X_train, y_train)
DecisionTreeClassifier(max_depth=10)
model.predict(X_test)
array([1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1])
pred = model.predict(X_test)
Our answer key is y_test, so let's compare the predictions against it.
y_test
258 1
57 0
225 1
704 0
1096 0
..
44 0
1399 1
1035 0
259 1
532 1
Name: sold, Length: 297, dtype: int64
from sklearn.metrics import accuracy_score, confusion_matrix
accuracy_score(y_test, pred)
0.8080808080808081
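Accuracy is the share of test rows where the prediction equals the true label. With 297 test rows, 0.8081 corresponds to 240 correct predictions (240 / 297). A sketch of the same computation by hand on toy arrays:

```python
import numpy as np
from sklearn.metrics import accuracy_score

y_true = np.array([1, 0, 1, 1, 0, 0, 1, 0])
y_pred = np.array([1, 0, 0, 1, 0, 1, 1, 0])

# Accuracy = fraction of positions where prediction matches the label
manual = np.mean(y_true == y_pred)
print(manual, accuracy_score(y_true, y_pred))  # 6 of 8 match
print(240 / 297)  # the ratio behind the 0.8081 above
```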
Let's tune max_depth to see whether the accuracy can go a bit higher.
for i in range(2, 31):
    model = DecisionTreeClassifier(max_depth = i)
    model.fit(X_train, y_train)
    pred = model.predict(X_test)
    print(i, round(accuracy_score(y_test, pred), 4))  # print accuracy to 4 decimal places
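max_depth | accuracy |
---|---|
2 | 0.8182 |
3 | 0.8215 |
4 | 0.8215 |
5 | 0.8182 |
6 | 0.8081 |
7 | 0.8013 |
8 | 0.8081 |
9 | 0.8047 |
10 | 0.8081 |
11 | 0.7879 |
12 | 0.7778 |
13 | 0.7744 |
14 | 0.7576 |
15 | 0.7677 |
16 | 0.7744 |
17 | 0.7744 |
18 | 0.7576 |
19 | 0.7677 |
20 | 0.7542 |
21 | 0.7609 |
22 | 0.7441 |
23 | 0.7508 |
24 | 0.7542 |
25 | 0.7576 |
26 | 0.7441 |
27 | 0.7407 |
28 | 0.7508 |
29 | 0.7475 |
30 | 0.7609 |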
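The loop shows the highest accuracy, 0.8215, at max_depth = 3 (tied with max_depth = 4), so we go with the simpler depth-3 tree. Accuracy generally falls as the tree gets deeper, which suggests the deeper trees overfit the training data.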
model = DecisionTreeClassifier(max_depth = 3)
model.fit(X_train, y_train)
pred = model.predict(X_test)
accuracy_score(y_test, pred)
0.8215488215488216
This raised the accuracy a little, from 0.8081 to 0.8215.
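One caveat: max_depth was chosen by looking at test-set accuracy, so 0.8215 is likely somewhat optimistic as an estimate for new listings. A common alternative, swapped in here rather than taken from the course, is to pick max_depth with cross-validation on the training set via `GridSearchCV` and score the test set only once at the end. A sketch on synthetic data from `make_classification` (a stand-in, not the galaxy data):

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic binary data standing in for the auction features
X, y = make_classification(n_samples=1485, n_features=20, random_state=100)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=100)

# 5-fold CV over the same max_depth range as the loop, using training data only
search = GridSearchCV(
    DecisionTreeClassifier(random_state=100),
    param_grid={"max_depth": list(range(2, 31))},
    cv=5,
    scoring="accuracy",
)
search.fit(X_train, y_train)

best_depth = search.best_params_["max_depth"]
test_acc = search.score(X_test, y_test)  # refit best model, scored once on held-out data
print(best_depth, round(search.best_score_, 4), round(test_acc, 4))
```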
confusion_matrix(y_test, pred)
array([[148,  13],
       [ 40,  96]])
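In scikit-learn's `confusion_matrix`, rows are the true labels and columns the predictions, both in the order [0, 1]. The matrix above therefore reads TN = 148, FP = 13, FN = 40, TP = 96: the model rarely calls an unsold item sold, but misses 40 of the 136 items that actually sold. A sketch deriving the usual metrics from those four numbers:

```python
import numpy as np

cm = np.array([[148, 13],
               [40, 96]])  # confusion matrix from above (rows: true, cols: predicted)
tn, fp, fn, tp = cm.ravel()

accuracy = (tp + tn) / cm.sum()  # 244 / 297
precision = tp / (tp + fp)       # of items predicted "sold", share that sold
recall = tp / (tp + fn)          # of items that sold, share the model caught
print(round(accuracy, 4), round(precision, 4), round(recall, 4))
```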
Drawing the tree plot
from sklearn.tree import plot_tree
plt.figure(figsize=(20,10))
plot_tree(model)
[Figure: decision tree (max_depth = 3); node text as printed by plot_tree, left branch = condition true]
- X[0] <= 0.5 (gini = 0.497, samples = 1188, value = [638, 550])
  - True: X[1] <= 208.495 (gini = 0.357, samples = 659, value = [506, 153])
    - True: X[17] <= 0.5 (gini = 0.442, samples = 240, value = [161, 79])
      - True: leaf (gini = 0.43, samples = 230, value = [158, 72])
      - False: leaf (gini = 0.42, samples = 10, value = [3, 7])
    - False: X[3] <= 5.5 (gini = 0.291, samples = 419, value = [345, 74])
      - True: leaf (gini = 0.319, samples = 367, value = [294, 73])
      - False: leaf (gini = 0.038, samples = 52, value = [51, 1])
  - False: X[1] <= 142.475 (gini = 0.375, samples = 529, value = [132, 397])
    - True: X[1] <= 59.995 (gini = 0.216, samples = 332, value = [41, 291])
      - True: leaf (gini = 0.108, samples = 210, value = [12, 198])
      - False: leaf (gini = 0.362, samples = 122, value = [29, 93])
    - False: X[1] <= 205.0 (gini = 0.497, samples = 197, value = [91, 106])
      - True: leaf (gini = 0.452, samples = 81, value = [28, 53])
      - False: leaf (gini = 0.496, samples = 116, value = [63, 53])
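Each node reports its Gini impurity, 1 − Σ p_k², where p_k is the share of class k among the node's samples. For the root, value = [638, 550] out of 1,188 samples gives 1 − (638/1188)² − (550/1188)² ≈ 0.497, the number printed for the root. A sketch of that calculation:

```python
def gini(counts):
    """Gini impurity of a node given its per-class sample counts."""
    total = sum(counts)
    return 1.0 - sum((c / total) ** 2 for c in counts)

root = gini([638, 550])    # root node of the depth-3 tree
pure_leaf = gini([51, 1])  # the nearly pure "not sold" leaf
print(round(root, 3), round(pure_leaf, 3))
```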
Let's pass a few more arguments so the plot is easier to read.
plt.figure(figsize=(20,10))
plot_tree(model, feature_names = list(X_train.columns), fontsize = 15, label = 'none', max_depth = 2)
[Figure: the same tree drawn down to depth 2 with feature names; splits shown: BuyItNow <= 0.5 at the root, then startprice <= 208.495 and startprice <= 142.475, then productline_Galaxy_S9 <= 0.5, upperCaseDescription <= 5.5, startprice <= 59.995 and startprice <= 205.0; deeper nodes appear as (...)]
- feature_names = list(X_train.columns) shows the independent variables' column names instead of X[0], X[1], ...
- fontsize = 15 sets the font size.
- label = 'none' hides the gini, samples, and value labels so only the numbers remain, which makes the plot cleaner to read.
- max_depth = 2 draws the tree only down to depth 2.
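The stated goal also includes checking how much each variable matters. A fitted tree exposes this as `feature_importances_`, the normalized total impurity decrease contributed by each feature; in the depth-3 tree above, only the four features used in splits (BuyItNow, startprice, productline_Galaxy_S9, upperCaseDescription) can have non-zero importance. With the real model, `pd.Series(model.feature_importances_, index=X_train.columns).sort_values(ascending=False)` would list them. A self-contained sketch on synthetic data (the column names are placeholders):

```python
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in data with placeholder column names
X, y = make_classification(n_samples=500, n_features=6, n_informative=3, random_state=0)
X = pd.DataFrame(X, columns=[f"feature_{i}" for i in range(6)])

model = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X, y)

# Impurity-based importance per feature, normalized to sum to 1
importances = pd.Series(model.feature_importances_, index=X.columns).sort_values(ascending=False)
print(importances)
```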
- Source: Fast Campus, E-commerce Data Analysis with Python