zoukankan      html  css  js  c++  java
  • 使用网格搜索优化模型参数

    1.简单网格搜索法

    • Lasso算法中不同的参数调整次数

    #############################  使用网格搜索优化模型参数 #######################################
    #导入套索回归模型
    from sklearn.linear_model import Lasso
    #导入数据集拆分工具
    from sklearn.model_selection import train_test_split
    #导入红酒数据集
    from sklearn.datasets import load_wine
    #载入红酒数据集
    wine = load_wine()
    #将数据集差分为训练集与测试集
    X_train,X_test,y_train,y_test = train_test_split(wine.data,wine.target,random_state=38)
    #设置初始分数为0
    best_score = 0
    #设置alpha参数遍历0.01,0.1,1,10
    for alpha in [0.01,0.1,1.0,10.0]:
        #最大迭代数遍历100,1000,5000,10000
        for max_iter in [100,1000,5000,10000]:
            lasso = Lasso(alpha=alpha,max_iter=max_iter)
            #训练套索回归模型
            lasso.fit(X_train,y_train)
            score = lasso.score(X_test,y_test)
            #令最佳分数为所有分数中的最高值
            if score > best_score:
                best_score = score
                #定义字典,返回最佳参数和最佳最大迭代数
                best_parameters = {'alpha':alpha,'最大迭代次数':max_iter}
    
    #打印结果
    print('模型最高分为:{:.3f}'.format(best_score))
    print('最佳参数设置:{}'.format(best_parameters))
    
    模型最高分为:0.889
    最佳参数设置:{'alpha': 0.01, '最大迭代次数': 100}
    #将数据集差分为训练集与测试集
    X_train,X_test,y_train,y_test = train_test_split(wine.data,wine.target,random_state=0)
    #设置初始分数为0
    best_score = 0
    #设置alpha参数遍历0.01,0.1,1,10
    for alpha in [0.01,0.1,1.0,10.0]:
        #最大迭代数遍历100,1000,5000,10000
        for max_iter in [100,1000,5000,10000]:
            lasso = Lasso(alpha=alpha,max_iter=max_iter)
            #训练套索回归模型
            lasso.fit(X_train,y_train)
            score = lasso.score(X_test,y_test)
            #令最佳分数为所有分数中的最高值
            if score > best_score:
                best_score = score
                #定义字典,返回最佳参数和最佳最大迭代数
                best_parameters = {'alpha':alpha,'最大迭代次数':max_iter}
    
    #打印结果
    print('模型最高分为:{:.3f}'.format(best_score))
    print('最佳参数设置:{}'.format(best_parameters))
    
    模型最高分为:0.830
    最佳参数设置:{'alpha': 0.1, '最大迭代次数': 100}

    2.与交叉验证结合的网格搜索

    #导入numpy
    import numpy as np
    #导入交叉验证工具
    from sklearn.model_selection import cross_val_score
    #设置alpha参数遍历0.01,0.1,1,10
    for alpha in [0.01,0.1,1.0,10.0]:
        #最大迭代数遍历100,1000,5000,10000
        for max_iter in [100,1000,5000,10000]:
            lasso = Lasso(alpha=alpha,max_iter=max_iter)
            #训练套索回归模型
            lasso.fit(X_train,y_train)
            scores = cross_val_score(lasso,X_train,y_train,cv=6)
            score = np.mean(scores)
            if score > best_score:
                best_score = score
                #定义字典,返回最佳参数和最佳最大迭代数
                best_parameters = {'alpha':alpha,'最大迭代次数':max_iter}
    
    #打印结果
    print('模型最高分为:{:.3f}'.format(best_score))
    print('最佳参数设置:{}'.format(best_parameters))
    
    模型最高分为:0.865
    最佳参数设置:{'alpha': 0.01, '最大迭代次数': 100}
    #用最佳参数模型拟合数据
    lasso = Lasso(alpha=0.01,max_iter=100).fit(X_train,y_train)
    #打印测试数据集得分
    print('测试数据集得分:{:.3f}'.format(lasso.score(X_test,y_test)))
    
    测试数据集得分:0.819
    #导入网格搜索工具
    from sklearn.model_selection import GridSearchCV
    #将需要遍历的参数定义为字典
    params = {'alpha':[0.01,0.1,1.0,10.0],'max_iter':[100,1000,5000,10000]}
    #定义网格搜索中使用的模型和参数
    grid_search = GridSearchCV(lasso,params,cv=6,iid=False)
    #使用网格搜索模型拟合数据
    grid_search.fit(X_train,y_train)
    #打印结果
    print('模型最高分:{:.3f}'.format(grid_search.score(X_test,y_test)))
    print('最有参数:{}'.format(grid_search.best_params_))
    
    模型最高分:0.819
    最有参数:{'alpha': 0.01, 'max_iter': 100}
    #打印网格搜索中的best_score_属性
    print('交叉验证最高得分:{:.3f}'.format(grid_search.best_score_))
    
    交叉验证最高得分:0.865

    总结 :  

    • GridSearchCV本身就是将交叉验证和网格搜索封装在一起的方法.
    • GridSearchCV虽然是个非常强悍的功能,但是由于需要反复建模,因此所需要的计算时间更长.

    文章引自 : 《深入浅出python机器学习》

  • 相关阅读:
    Angular 的性能优化
    通往架构师之路的三本书,高分!
    从单体架构到微服务架构的演化历程
    nginx 配置stream模块代理并开启日志配置
    UnixLinux 执行 shell 报错:“$' ': 未找到命令” 的解决办法
    红胖子(红模仿)的博文大全:开发技术集合大版本更新v4.0.0
    案例分享:Qt高频fpga采集数据压力位移速度加速度分析系统(通道配置、电压转换、采样频率、通道补偿、定时采集、距离采集,导出exce、自动XY轴、隐藏XY轴、隐藏显示通道,文件回放等等)
    字符编码和字符集到底有什么区别?Unicode和UTF-8是什么关系?
    Linux从头学15:【页目录和页表】-理论 + 实例 + 图文的最完全、最接地气详解
    【分页机制】-看了这篇文章还没彻底搞懂?我自罚三杯!
  • 原文地址:https://www.cnblogs.com/weijiazheng/p/10966005.html
Copyright © 2011-2022 走看看