zoukankan      html  css  js  c++  java
  • 决策树

    决策树算法:主要通过信息熵或者gini系数来作为衡量标准

    当完成决策树时需要进行剪枝操作,在剪枝过程中,我们一般采用预剪枝的操作(该操作更加实用)

    预剪枝过程中的几个限制条件:

                                               1. 限制深度

                                               2. 叶子节点个数

                                               3.叶子节点样本数

                                               4.信息增益量

                                               ..... 

    下面以一个房屋数据为列子

    复制代码
    from sklearn.datasets.california_housing import  fetch_california_housing
    housing = fetch_california_housing()  #导入房屋数据
    

    print(housing.data.shape)

    from sklearn import tree
    dtr
    = tree.DecisionTreeRegressor(max_depth=2) #设置数的限制条件
    dtr.fit(housing.data[:, [6,7]], housing.target) #以data数据的第6和第7个特征,以及目标函数建立模型

    # 图像可视化
    dot_data =
    tree.export_graphviz(
    dtr,
    out_file
    = None,
    feature_names
    = housing.feature_names[6:8],
    filled
    = True,
    impurity
    = False,
    rounded
    = True,
    )

    import pydotplus
    graph
    = pydotplus.graph_from_dot_data(dot_data)
    graph.get_nodes()[
    7].set_fillcolor("#FFF2DD") #上色
    from IPython.display import Image
    Image(graph.create_png())
    graph.write_png(
    "dtr_white_background.png") #图片保存为png格式
    plt.show()

    复制代码

    现在采用所有的变量构建参数,但是速度较慢

    复制代码
    from sklearn.model_selection import train_test_split
    

    data_train, data_test, target_train, target_test = train_test_split(housing.data, housing.target,
    test_size
    =0.1, random_state=42) #test_size 表示分割的百分比, random_state表示随机种子,每次随机的结果不变
    # 分割数据
    dtr = tree.DecisionTreeRegressor(random_state=42) #构建树
    dtr.fit(data_train, target_train) #建立模型
    dtr.score(data_train, target_train) #模型得分,得分越高模型效果越好

    #构建决策树图像
    dot_data =
    tree.export_graphviz(
    dtr,
    #树的名称
    out_file = None,
    feature_names
    = housing.feature_names,
    filled
    = True,
    impurity
    = False,
    rounded
    = True,
    )
    import pydotplus
    graph
    = pydotplus.graph_from_dot_data(dot_data)
    graph.get_nodes()[
    7].set_fillcolor("#FFF2DD")
    from IPython.display import Image
    Image(graph.create_png())
    graph.write_png(
    "dtr_white_background_1.png") #png分辨率较高
    plt.show()

    复制代码
  • 相关阅读:
    UPC-5930 Rest Stops(水题)
    UPC-6199 LCYZ的道路(贪心)
    UPC-6198 JL的智力大冲浪(简单贪心)
    POJ 3279 Filptile dfs
    hrbust 1621 迷宫问题II 广搜
    HDU 1045 dfs + 回溯
    优先队列基本用法
    树。森林。和二叉树之间的转换
    POJ 2689 筛法求素数
    哈理工OJ 1328
  • 原文地址:https://www.cnblogs.com/litieshuai/p/11388294.html
Copyright © 2011-2022 走看看