zoukankan      html  css  js  c++  java
  • 吴裕雄 python 机器学习——分类决策树模型

    import numpy as np
    import matplotlib.pyplot as plt
    
    from sklearn import datasets
    from sklearn.model_selection import train_test_split
    from sklearn.tree import DecisionTreeClassifier,DecisionTreeRegressor
    
    def load_data():
        '''
        Load the iris data set bundled with scikit-learn and split it for a
        classification task.

        Returns the result of train_test_split, which callers unpack as
        X_train, X_test, y_train, y_test. The split uses a 25% test fraction,
        a fixed random_state of 0 and stratification on the class labels.
        '''
        dataset = datasets.load_iris()
        features, labels = dataset.data, dataset.target
        return train_test_split(
            features,
            labels,
            test_size=0.25,
            random_state=0,
            stratify=labels,
        )
    
    # Decision tree classification model (DecisionTreeClassifier)
    def test_DecisionTreeClassifier(*data):
        '''
        Fit a DecisionTreeClassifier with default parameters and print its
        score on the training set and on the test set.

        data: X_train, X_test, y_train, y_test, passed positionally in that
        order.
        '''
        train_X, test_X, train_y, test_y = data
        model = DecisionTreeClassifier()
        model.fit(train_X, train_y)
        train_score = model.score(train_X, train_y)
        print("Training score:%f" % train_score)
        test_score = model.score(test_X, test_y)
        print("Testing score:%f" % test_score)
        
    # Build the train/test split used by the classification experiments
    X_train, X_test, y_train, y_test = load_data()
    # Run the baseline experiment: test_DecisionTreeClassifier
    test_DecisionTreeClassifier(X_train, X_test, y_train, y_test)

    def test_DecisionTreeClassifier_criterion(*data):
        '''
        Show how the criterion parameter affects DecisionTreeClassifier scores.

        Fits one tree per criterion ('gini', then 'entropy') and prints the
        criterion name followed by the training and testing scores.

        data: X_train, X_test, y_train, y_test, passed positionally in that
        order.
        '''
        train_X, test_X, train_y, test_y = data
        for name in ('gini', 'entropy'):
            model = DecisionTreeClassifier(criterion=name)
            model.fit(train_X, train_y)
            print("criterion:%s" % name)
            print("Training score:%f" % model.score(train_X, train_y))
            print("Testing score:%f" % model.score(test_X, test_y))
            
    # Run the criterion comparison: test_DecisionTreeClassifier_criterion
    test_DecisionTreeClassifier_criterion(X_train, X_test, y_train, y_test)

    def test_DecisionTreeClassifier_splitter(*data):
        '''
        Show how the splitter parameter affects DecisionTreeClassifier scores.

        Fits one tree per splitter ('best', then 'random') and prints the
        splitter name followed by the training and testing scores.

        data: X_train, X_test, y_train, y_test, passed positionally in that
        order.
        '''
        train_X, test_X, train_y, test_y = data
        for strategy in ('best', 'random'):
            model = DecisionTreeClassifier(splitter=strategy)
            model.fit(train_X, train_y)
            print("splitter:%s" % strategy)
            print("Training score:%f" % model.score(train_X, train_y))
            print("Testing score:%f" % model.score(test_X, test_y))
            
    # Run the splitter comparison: test_DecisionTreeClassifier_splitter
    test_DecisionTreeClassifier_splitter(X_train, X_test, y_train, y_test)

    def test_DecisionTreeClassifier_depth(*data,maxdepth):
        '''
        Plot how DecisionTreeClassifier scores change with the max_depth
        parameter.

        One tree is fitted for each depth in 1 .. maxdepth-1 (the stop value
        of np.arange is exclusive, so maxdepth itself is not tried). The
        training and testing scores are collected and drawn as two curves on
        a single matplotlib figure, which is then shown with plt.show().

        data: X_train, X_test, y_train, y_test, passed positionally in that
        order.
        maxdepth: keyword-only; exclusive upper bound of the depths to try.
        '''
        X_train,X_test,y_train,y_test=data
        depths=np.arange(1,maxdepth)
        training_scores=[]
        testing_scores=[]
        for depth in depths:
            clf = DecisionTreeClassifier(max_depth=depth)
            clf.fit(X_train, y_train)
            training_scores.append(clf.score(X_train,y_train))
            testing_scores.append(clf.score(X_test,y_test))

        ## Plot both score curves against depth
        fig=plt.figure()
        ax=fig.add_subplot(1,1,1)
        # Legend label previously read "traing score" (typo)
        ax.plot(depths,training_scores,label="training score",marker='o')
        ax.plot(depths,testing_scores,label="testing score",marker='*')
        ax.set_xlabel("maxdepth")
        ax.set_ylabel("score")
        ax.set_title("Decision Tree Classification")
        ax.legend(framealpha=0.5,loc='best')
        plt.show()
        
    # Run the depth sweep: test_DecisionTreeClassifier_depth
    test_DecisionTreeClassifier_depth(
        X_train, X_test, y_train, y_test, maxdepth=100
    )

    import os
    import pydotplus
    
    from io import StringIO
    from sklearn.tree import export_graphviz
    from sklearn.tree import DecisionTreeClassifier,DecisionTreeRegressor
    
    # Re-create the split, fit a default tree and export it with
    # export_graphviz to the hard-coded output path below
    X_train, X_test, y_train, y_test = load_data()
    clf = DecisionTreeClassifier()
    clf.fit(X_train, y_train)
    export_graphviz(clf, out_file="F://out")

  • 相关阅读:
    MapServer:地图发布工具
    hdu1054(二分图匹配)
    hdu 5091(线段树+扫描线)
    hdu1828(线段树+扫描线)
    hdu2847(暴力)
    hdu1052(田忌赛马 贪心)
    hdu1051(LIS | Dilworth定理)
    hdu1050(贪心)
    poj 2773(容斥原理)
    hdu 1044(bfs+状压)
  • 原文地址:https://www.cnblogs.com/tszr/p/10791023.html
Copyright © 2011-2022 走看看