

Observing the effect of changing sklearn.tree.DecisionTreeClassifier.max_depth

givemebro 2020. 4. 20. 10:44
import numpy as np
import matplotlib.pyplot as plt

from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

cancer = load_breast_cancer()

# random_state=0 makes the split reproducible (added; the original split was random on every run)
X_train, X_test, y_train, y_test = train_test_split(
    cancer.data, cancer.target, random_state=0)

score_train = []
score_test = []

# Fit one tree per max_depth value and record accuracy on both sets
for depth in range(1, 10):
    model = DecisionTreeClassifier(max_depth=depth, random_state=2020)
    model.fit(X_train, y_train)
    score_train.append(model.score(X_train, y_train))  # training accuracy
    score_test.append(model.score(X_test, y_test))     # held-out accuracy

# red dashed line with circles: train; blue solid line with squares: test
plt.plot(range(1, 10), score_train, 'ro--')
plt.plot(range(1, 10), score_test, 'bs-')
plt.legend(['train', 'test'])
plt.xticks(range(1, 10))
plt.xlabel('max_depth')
plt.ylabel('accuracy')
plt.show()

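To relate the two accuracy curves to model size, it can help to check how large each fitted tree actually gets. The sketch below is a minimal illustration, assuming the same breast cancer data and a fixed split (`random_state=0` is an arbitrary choice, not from the original). It uses `get_depth()` and `get_n_leaves()` to report the realized tree size: `max_depth` is only an upper bound, and a tree with depth d has at most 2^d leaves.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

cancer = load_breast_cancer()
# random_state=0 is an assumption, used only to make the split reproducible
X_train, X_test, y_train, y_test = train_test_split(
    cancer.data, cancer.target, random_state=0)

rows = []
for depth in range(1, 10):
    model = DecisionTreeClassifier(max_depth=depth, random_state=2020)
    model.fit(X_train, y_train)
    # Record the limit, the depth the tree actually reached, its leaf count,
    # and the train/test accuracy
    rows.append((depth, model.get_depth(), model.get_n_leaves(),
                 model.score(X_train, y_train), model.score(X_test, y_test)))

print("max_depth  actual_depth  leaves  train   test")
for depth, actual, leaves, tr, te in rows:
    print(f"{depth:9d}  {actual:12d}  {leaves:6d}  {tr:.3f}  {te:.3f}")
```

If the actual depth stops increasing before `max_depth = 9`, the tree has already separated the training data as far as it can. From that point on, raising the limit changes nothing.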
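import mglearn  # helper package for the book "Introduction to Machine Learning with Python"

# Use only two features so the decision regions can be drawn in 2D
col1 = 0  # mean radius
col2 = 1  # mean texture

X_train, X_test, y_train, y_test = train_test_split(
    cancer.data[:, [col1, col2]], cancer.target, random_state=0)

plt.figure(figsize=[14, 16])
for depth in range(1, 10):
    model = DecisionTreeClassifier(max_depth=depth, random_state=2020)
    model.fit(X_train, y_train)

    plt.subplot(3, 3, depth)
    plt.title('max_depth = ' + str(depth))
    # Shade the region assigned to each class, then overlay the training points.
    # X_train holds only the two selected features, so index its columns as 0 and 1
    # (indexing with col1/col2 would break for any columns other than 0 and 1).
    mglearn.plots.plot_2d_classification(model, X_train)
    mglearn.discrete_scatter(X_train[:, 0], X_train[:, 1], y_train, alpha=0.3)
    plt.xlabel(cancer.feature_names[col1])
    plt.ylabel(cancer.feature_names[col2])

plt.tight_layout()
plt.show()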
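The decision boundaries in these panels consist only of horizontal and vertical segments, because every split compares a single feature against a threshold. The sketch below is a minimal illustration, assuming the same two columns (`mean radius`, `mean texture`) and a fixed split (`random_state=0` is an arbitrary choice, not from the original). It uses `sklearn.tree.export_text` to print the thresholds a `max_depth=2` tree learns. Because the split differs from the one in the original run, the exact numbers may not match the plotted panel.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_text

cancer = load_breast_cancer()
col1, col2 = 0, 1  # the same two columns used for the 2D plots
X2 = cancer.data[:, [col1, col2]]
names = [str(cancer.feature_names[col1]), str(cancer.feature_names[col2])]

# random_state=0 for the split is an assumption, chosen only for reproducibility
X_train, X_test, y_train, y_test = train_test_split(
    X2, cancer.target, random_state=0)

model = DecisionTreeClassifier(max_depth=2, random_state=2020).fit(X_train, y_train)

# Each internal node tests one feature against one threshold, so the decision
# regions are axis-aligned rectangles; each "class:" line is one leaf/region.
report = export_text(model, feature_names=names)
print(report)
```

Each printed threshold corresponds to one vertical line (a `mean radius` split) or one horizontal line (a `mean texture` split) in a panel. Deeper trees add more of these lines, which is why the regions fragment as `max_depth` grows.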
