zoukankan      html  css  js  c++  java
  • 使用GridSearchCV寻找最佳参数组合——机器学习工具箱代码

    # -*- coding: utf-8 -*-
    import numpy as np
    from sklearn.feature_extraction import FeatureHasher
    from sklearn import datasets
    from sklearn.ensemble import GradientBoostingClassifier
    from sklearn.neighbors import KNeighborsClassifier
    import xgboost as xgb
    from sklearn.model_selection import GridSearchCV
    from sklearn.model_selection import train_test_split
    from sklearn import metrics
    from matplotlib import pyplot as plt
    from sklearn.ensemble import GradientBoostingClassifier
    from sklearn.model_selection import GridSearchCV
    
    def report(test_Y, pred_Y):
        """Print classification metrics for a set of predictions and plot a ROC curve.

        Each metric is printed as a label line followed by the value line, in
        this order: accuracy, F1, recall, precision, confusion matrix, AUC.
        Afterwards a ROC curve is drawn with matplotlib and shown via
        ``plt.show()``.

        Args:
            test_Y: Ground-truth labels.
            pred_Y: Predicted values for the same samples. They are passed
                unchanged both to the label-based metrics and, as the score
                argument, to ``roc_auc_score`` / ``roc_curve``.

        NOTE(review): f1/recall/precision are called with their default
        averaging, so this presumably expects a two-class problem -- confirm
        against callers.
        """
        # (label, metric function) pairs, evaluated and printed in this order.
        labelled_metrics = (
            ("accuracy_score:", metrics.accuracy_score),
            ("f1_score:", metrics.f1_score),
            ("recall_score:", metrics.recall_score),
            ("precision_score:", metrics.precision_score),
            ("confusion_matrix:", metrics.confusion_matrix),
            ("AUC:", metrics.roc_auc_score),
        )
        for caption, metric_fn in labelled_metrics:
            print(caption)
            print(metric_fn(test_Y, pred_Y))

        # ROC curve; the thresholds returned by roc_curve are not needed.
        fpr, tpr, _ = metrics.roc_curve(test_Y, pred_Y)
        area = metrics.auc(fpr, tpr)
        curve_label = 'AUC = %.2f' % area
        plt.plot(fpr, tpr, 'darkorange', lw=2, label=curve_label)
        plt.legend(loc='lower right')
        # Diagonal reference line from (0, 0) to (1, 1).
        plt.plot([0, 1], [0, 1], color='navy', linestyle='--')
        plt.title('ROC')
        plt.ylabel('True Pos Rate')
        plt.xlabel('False Pos Rate')
        plt.show()
    
    
    
    if __name__ == '__main__':
        # Synthetic classification data set: 1000 samples, 100 features, no
        # redundant features; seeded so the data itself is reproducible.
        x, y = datasets.make_classification(n_samples=1000, n_features=100,
                                            n_redundant=0, random_state=1)

        # Hold-out split. It is only consumed by the optional baseline below;
        # the grid search cross-validates on the full (x, y).
        train_X, test_X, train_Y, test_Y = train_test_split(x,
                                                            y,
                                                            test_size=0.2,
                                                            random_state=66)

        # Optional baseline: a single un-tuned model evaluated on the hold-out set.
        #clf = GradientBoostingClassifier(n_estimators=100)
        #clf.fit(train_X, train_Y)
        #pred_Y = clf.predict(test_X)
        #report(test_Y, pred_Y)

        # Metric used to rank parameter combinations. Previously a stray,
        # unused "f1" variable sat next to a hard-coded 'accuracy' argument;
        # the value actually used ('accuracy') is now the single source of truth.
        scoring = 'accuracy'
        parameters = {
            'n_estimators': range(50, 200, 25),
            'max_depth': range(2, 10, 2),
        }

        # The `iid` argument was removed from GridSearchCV in scikit-learn 0.24
        # and passing it raises TypeError there, so it is no longer supplied.
        # Since 0.22 the default already matches the old iid=False behaviour
        # (scores are averaged across folds without weighting by fold size).
        gsearch = GridSearchCV(estimator=GradientBoostingClassifier(),
                               param_grid=parameters,
                               scoring=scoring,
                               cv=5)
        gsearch.fit(x, y)

        print("gsearch.best_params_")
        print(gsearch.best_params_)
        print("gsearch.best_score_")
        print(gsearch.best_score_)
    

     效果:

    gsearch.best_params_
    {'max_depth': 4, 'n_estimators': 100}
    gsearch.best_score_
    0.868142228555714

  • 相关阅读:
    centos7环境下安装mysql5.6-----解压安装包的方法
    Linux的常用命令
    在同一个类中,一个方法调用另外一个有注解(比如@Async,@Transational)的方法,注解失效的原因和解决方法
    springboot下实现邮件发送功能
    centos7环境下开启指定端口
    阿里云开放指定端口
    Nginx的alias的用法及与root的区别
    关于Springboot打包错误的问题 | Failed to execute goal org.springframework.boot:spring-boot-maven-plugin
    怎么简单高效破解MyEclipse10、获取注册码
    git删除远程分支
  • 原文地址:https://www.cnblogs.com/bonelee/p/9154171.html
Copyright © 2011-2022 走看看