zoukankan      html  css  js  c++  java
  • 决策树

    决策树算法:主要通过信息熵或者gini系数来作为衡量标准

    当完成决策树时需要进行剪枝操作,在剪枝过程中,我们一般采用预剪枝的操作(该操作更加实用)

    预剪枝过程中的几个限制条件:

                                               1. 限制深度

                                               2. 叶子节点个数

                                               3.叶子节点样本数

                                               4.信息增益量

                                               ..... 

    下面以一个房屋数据为列子

    复制代码
    from sklearn.datasets.california_housing import  fetch_california_housing
    housing = fetch_california_housing()  #导入房屋数据
    

    print(housing.data.shape)

    from sklearn import tree
    dtr
    = tree.DecisionTreeRegressor(max_depth=2) #设置数的限制条件
    dtr.fit(housing.data[:, [6,7]], housing.target) #以data数据的第6和第7个特征,以及目标函数建立模型

    # 图像可视化
    dot_data =
    tree.export_graphviz(
    dtr,
    out_file
    = None,
    feature_names
    = housing.feature_names[6:8],
    filled
    = True,
    impurity
    = False,
    rounded
    = True,
    )

    import pydotplus
    graph
    = pydotplus.graph_from_dot_data(dot_data)
    graph.get_nodes()[
    7].set_fillcolor("#FFF2DD") #上色
    from IPython.display import Image
    Image(graph.create_png())
    graph.write_png(
    "dtr_white_background.png") #图片保存为png格式
    plt.show()

    复制代码

    现在采用所有的变量构建参数,但是速度较慢

    复制代码
    from sklearn.model_selection import train_test_split
    

    data_train, data_test, target_train, target_test = train_test_split(housing.data, housing.target,
    test_size
    =0.1, random_state=42) #test_size 表示分割的百分比, random_state表示随机种子,每次随机的结果不变
    # 分割数据
    dtr = tree.DecisionTreeRegressor(random_state=42) #构建树
    dtr.fit(data_train, target_train) #建立模型
    dtr.score(data_train, target_train) #模型得分,得分越高模型效果越好

    #构建决策树图像
    dot_data =
    tree.export_graphviz(
    dtr,
    #树的名称
    out_file = None,
    feature_names
    = housing.feature_names,
    filled
    = True,
    impurity
    = False,
    rounded
    = True,
    )
    import pydotplus
    graph
    = pydotplus.graph_from_dot_data(dot_data)
    graph.get_nodes()[
    7].set_fillcolor("#FFF2DD")
    from IPython.display import Image
    Image(graph.create_png())
    graph.write_png(
    "dtr_white_background_1.png") #png分辨率较高
    plt.show()

    复制代码
  • 相关阅读:
    第十二周进度表
    第一个冲刺周期-第十天
    第一个冲刺周期-第九天
    团队作业—第二阶段06
    团队作业—第二阶段05
    团队作业—第二阶段04
    团队作业—第二阶段03
    团队作业—第二阶段02
    团队作业—第二阶段01
    对于风行小组第一阶段冲刺成果的概括
  • 原文地址:https://www.cnblogs.com/litieshuai/p/11388294.html
Copyright © 2011-2022 走看看