zoukankan      html  css  js  c++  java
  • 随机森林分类算法

    随机森林(Random Forest,简称RF)是通过集成学习的思想将多棵树集成的一种算法,它的基本单元是决策树。假设现在针对的是分类问题,每棵决策树都是一个分类器,那么N棵树会有N个分类结果。随机森林集成了所有的分类投票结果,将投票次数最多的类别指定为最终输出。它可以很方便的并行训练。
    森林表示决策树是多个。随机表现为两个方面:数据的随机性化、待选特征的随机化。
     
    构建流程:采取有放回的抽样方式构造子数据集,保证不同子集之间的数量级一样(元素可以重复);利用子数据集来构建子决策树;将待预测数据放到每个子决策树中,每个子决策树输出一个结果;统计子决策树的投票结果,投票数多的就是随机森林的输出结果。
    (1)从样本集中用 Bootstrap采样选出一定数量的样本,比如80%样本集;
    (2)从所有属性中随机选择K个属性,在K个属性中再选择出最佳分割属性作为节点创建决策树;
    (3)重复以上两步m次,即建立m棵决策树。可以并行:即m个样本同时提取,m棵决策树同时生成;
    (4)这m个决策树形成随机森林,通过投票表决结果(比如少数服从多数)决定待预测数据的结果。
     
    随机森林和决策树在单个决策树上的构建区别是:所有特征变成随机部分特征。部分数量是K个,K的取值有一定的讲究。太小了使得单棵树的精度太低,太大了树之间的相关性会加强,独立性会减弱。通常取总特征数的平方根。
    PS1:有的文章介绍随机森林的使用时,对特征并没有随机部分选取,还是同决策树全部选取的。
    PS2:现实情况下,一个数据集往往有成百个特征,如何选择对结果影响较大的那几个特征,以此来缩减建立模型时的特征数。可以参考:http://www.sohu.com/a/297967370_729271 五,特征重要性评估。
     
     
    代码示例
    from sklearn.datasets import load_iris
    from sklearn.ensemble import RandomForestClassifier
    import pandas as pd
    import numpy as np
     
    iris = load_iris()
    df = pd.DataFrame(iris.data, columns=iris.feature_names)
    df['is_train'] = np.random.uniform(0, 1, len(df)) <= 0.80
    df['species'] = iris.target
    # 得到训练集和测试集
    train, test = df[df['is_train']==True], df[df['is_train']==False]
    # 定义特征列
    features = df.columns[:4]
     
    # 训练模型,限制树的最大深度3,决策树个数10
    clf = RandomForestClassifier(max_depth=3, n_estimators=10, max_features=1)
    # 获得训练集中真实结果
    Y, _ = pd.factorize(train['species'])
    # 开始训练
    clf.fit(train[features], Y)
     
    # 通过测试集进行预测
    preds = clf.predict(test[features])
    print (preds)
    print (test['species'].values)
     
    # 显示预测准确率
    diff = 0
    for num in range(0,len(preds)):
       if(preds[num] != test['species'].values[num]):
           diff = diff + 1
    rate = ((diff+0.0) / len(preds))
    print (1.0-rate)
     
     
     
  • 相关阅读:
    CentOS6.5配置MySQL主从同步
    CentOS6.5安装telnet
    linux 下安装Google Chrome (ubuntu 12.04)
    jdk w7环境变量配置
    JDBCConnectionException: could not execute query,数据库连接池问题
    注意开发软件的版本问题!
    linux mysql命令行导入导出.sql文件 (ubuntu 12.04)
    linux 下root用户和user用户的相互切换 (ubuntu 12.04)
    linux 下 vim 的使用 (ubuntu 12.04)
    linux 下安装配置tomcat-7 (ubuntu 12.04)
  • 原文地址:https://www.cnblogs.com/myshuzhimei/p/11746846.html
Copyright © 2011-2022 走看看