zoukankan      html  css  js  c++  java
  • 跟我学算法-决策树

    决策树算法:主要通过信息熵或者gini系数来作为衡量标准

    当完成决策树时需要进行剪枝操作,在剪枝过程中,我们一般采用预剪枝的操作(该操作更加实用)

    预剪枝过程中的几个限制条件:

                                               1. 限制深度

                                               2. 叶子节点个数

                                               3.叶子节点样本数

                                               4.信息增益量

                                               ..... 

    下面以一个房屋数据为列子

    from sklearn.datasets.california_housing import  fetch_california_housing
    housing = fetch_california_housing()  #导入房屋数据
    
    print(housing.data.shape)
    
    from sklearn import  tree
    dtr = tree.DecisionTreeRegressor(max_depth=2)  #设置数的限制条件
    dtr.fit(housing.data[:, [6,7]], housing.target) #以data数据的第6和第7个特征,以及目标函数建立模型
    
    # 图像可视化
    dot_data = 
    tree.export_graphviz(
        dtr,
        out_file = None,
        feature_names = housing.feature_names[6:8],
        filled = True,
        impurity= False,
        rounded= True,
    )
    
    import pydotplus
    graph = pydotplus.graph_from_dot_data(dot_data)
    graph.get_nodes()[7].set_fillcolor("#FFF2DD")  #上色
    from IPython.display import Image
    Image(graph.create_png())
    graph.write_png("dtr_white_background.png")  #图片保存为png格式
    plt.show()

    现在采用所有的变量构建参数,但是速度较慢

    from sklearn.model_selection import train_test_split
    
    data_train, data_test, target_train, target_test = train_test_split(housing.data, housing.target, 
            test_size=0.1, random_state=42) #test_size 表示分割的百分比, random_state表示随机种子,每次随机的结果不变
     # 分割数据
    dtr = tree.DecisionTreeRegressor(random_state=42) #构建树
    dtr.fit(data_train, target_train)  #建立模型
    dtr.score(data_train, target_train) #模型得分,得分越高模型效果越好
    
    #构建决策树图像
    dot_data = 
    tree.export_graphviz(
        dtr, #树的名称
        out_file = None,
        feature_names = housing.feature_names,
        filled = True,
        impurity = False,
        rounded = True,
    )
    import pydotplus
    graph = pydotplus.graph_from_dot_data(dot_data)
    graph.get_nodes()[7].set_fillcolor("#FFF2DD")
    from IPython.display import Image
    Image(graph.create_png())
    graph.write_png("dtr_white_background_1.png")  #png分辨率较高
    plt.show()
  • 相关阅读:
    UVA 11859
    [OpenGL]OpenGL坐标系和坐标变换
    树状数组
    编程算法
    乞讨 间隔[a,b]在见面p^k*q*^m(k>m)中数号码
    解析Android的 消息传递机制Handler
    Atitit.故障排除系列---php 计划网站数据库错误排除过程
    Remove Element
    [Angular Directive] Write a Structural Directive in Angular 2
    [Compose] 18. Maintaining structure whilst asyncing
  • 原文地址:https://www.cnblogs.com/my-love-is-python/p/9508101.html
Copyright © 2011-2022 走看看