zoukankan      html  css  js  c++  java
  • 交叉验证

    交叉验证

    以鸢尾花数据集为例

    from sklearn.datasets import load_iris
    iris = load_iris()
    data = iris.data
    target = iris.target
    # 交叉验证 把数据集分成 不同的训练集 和 测试集 然后多次测算模型的准确率
    # cross_val cross validate 交叉验证
    from sklearn.model_selection import cross_val_score
    from sklearn.neighbors import KNeighborsClassifier
    knn = KNeighborsClassifier()
    # estimator, X, y
    # estimator 要测算的模型
    # X,y 数据集的 特征值 和 目标值
    # cv 指的是交叉验证的次数 现在默认是3 将来可能会变成5次
    cross_val_score(knn,data,target,cv=10)

    得到的是一个有10个元素的一维数组

    cross_val_score(knn,data,target,cv=10).mean() # 0.966

    网格搜索

    网格搜索是针对参数使用不同的参数来看哪一个参数的情况下 模型的效果更好

    # GridSearchCV
    # Grid网格 Search搜索 CV交叉验证
    from sklearn.model_selection import GridSearchCV
    KNeighborsClassifier()  # n_neighbors
    
    # estimator, param_grid,
    # estimator要测试的模型
    # param_grid参数的网格 (传入要测试的不同的参数)
    
    param_grid = {
        # 'n_neighbors': range(6,10)
        'n_neighbors': [6,7,8,9], # 取邻近点的个数k。k取1-9测试
        'weights': ['uniform','distance'],# uniform:一致的权重;distance:距离的倒数作为权重
        'p':(1,2)  # 列表和元组都可以 p=1欧式距离 p=2曼哈顿距离
    }
    
    grid = GridSearchCV(knn,param_grid,cv=3,iid=True)  # 获得一个等待训练的空模型

    # grid 必须先训练 训练好之后 才能获取各种 best...

    grid.fit(data,target)
    grid.score(data,target)
    grid.best_estimator_
    grid.best_params_

    集成学习

    随机森林就是一种集成学习的方式 只不过随机森林中所有的模型都是决策树
    我们这里可以使用一种投票的方式 可以结合各种不同的 模型来使用

    from sklearn.ensemble import VotingClassifier
    # 投票分类器
    # VotingClassifier
    from sklearn.neighbors import KNeighborsClassifier
    from sklearn.linear_model import LogisticRegression
    from sklearn.tree import DecisionTreeClassifier
    from sklearn.naive_bayes import GaussianNB
    from sklearn.svm import SVC
    # 注意模型数量是奇数个 防止投出平票
    
    knn = KNeighborsClassifier()
    lgc = LogisticRegression()
    dtree = DecisionTreeClassifier()
    gnb = GaussianNB()
    svc = SVC()
    
    # estimators  传入各个要使用的模型
    # list of (string, estimator) tuples  由好多元组组成的列表
    estimators = []
    estimators.append(('knn',knn))  # 以 名字 和 模型对象 构成的一个元组
    estimators.append(('lgc',lgc))
    estimators.append(('dtree',dtree))
    estimators.append(('nb',gnb))
    estimators.append(('svc',svc))
    
    voting = VotingClassifier(estimators)

    切分训练集与测试集并训练上面获取到的模型

    from sklearn.model_selection import train_test_split
    X_train,X_test,y_train,y_test = train_test_split(data,target)
    voting.fit(X_train,y_train)
    
    voting.score(X_test,y_test)  # 每一个测试样本 使用不同的5个模型 都得到一个结果 看哪一个分类结果 得票多
    # 投票的方式 可能不会把分数提高很多
    # 但是 如果数据没问题的话 至少能保证 准确率不会低得离谱
  • 相关阅读:
    索引法则--少用OR,它在连接时会索引失效
    索引法则--LIKE以%开头会导致索引失效进而转向全表扫描(使用覆盖索引解决)
    索引法则--字符串不加单引号会导致索引失效
    索引法则--IS NULL, IS NOT NULL 也无法使用索引
    tomcat管理模块报401 Unauthorized
    MySQL报Too many connections
    JDBC连接MySql例子
    linux安装jdk并设置环境变量(看这一篇文章即可)
    深度解析Java可变参数类型以及与数组的区别
    MySQL真正的UTF-8字符集utf8mb4
  • 原文地址:https://www.cnblogs.com/louyifei0824/p/10006141.html
Copyright © 2011-2022 走看看