zoukankan      html  css  js  c++  java
  • 使用GridSearchCV寻找最佳参数组合——机器学习工具箱代码

    # -*- coding: utf-8 -*-
    import numpy as np
    from sklearn.feature_extraction import FeatureHasher
    from sklearn import datasets
    from sklearn.ensemble import GradientBoostingClassifier
    from sklearn.neighbors import KNeighborsClassifier
    import xgboost as xgb
    from sklearn.model_selection import GridSearchCV
    from sklearn.model_selection import train_test_split
    from sklearn import metrics
    from matplotlib import pyplot as plt
    from sklearn.ensemble import GradientBoostingClassifier
    from sklearn.model_selection import GridSearchCV
    
    def report(test_Y, pred_Y, pred_score=None):
        """Print binary-classification metrics and plot the ROC curve.

        Args:
            test_Y: ground-truth labels of the evaluation set.
            pred_Y: hard predicted labels (e.g. from ``clf.predict``). Used for
                accuracy, F1, recall, precision and the confusion matrix.
            pred_score: optional continuous scores for the positive class,
                e.g. ``clf.predict_proba(X)[:, 1]`` or ``clf.decision_function(X)``.
                ROC/AUC are ranking metrics. Computed from hard 0/1 labels, the
                ROC "curve" collapses to a single operating point. When given,
                these scores drive the AUC value and the ROC plot. When omitted,
                ``pred_Y`` is used, exactly as in the two-argument form.
        """
        # Two-argument callers keep the previous behaviour (labels used as scores).
        score_Y = pred_Y if pred_score is None else pred_score

        # Label-based metrics: each takes (y_true, y_pred).
        label_metrics = (
            ("accuracy_score", metrics.accuracy_score),
            ("f1_score", metrics.f1_score),
            ("recall_score", metrics.recall_score),
            ("precision_score", metrics.precision_score),
            ("confusion_matrix", metrics.confusion_matrix),
        )
        for name, metric in label_metrics:
            print(name + ":")
            print(metric(test_Y, pred_Y))

        print("AUC:")
        print(metrics.roc_auc_score(test_Y, score_Y))

        # Thresholds returned by roc_curve are not needed for the plot.
        f_pos, t_pos, _ = metrics.roc_curve(test_Y, score_Y)
        auc_area = metrics.auc(f_pos, t_pos)
        plt.plot(f_pos, t_pos, 'darkorange', lw=2, label='AUC = %.2f' % auc_area)
        # Chance-level diagonal for reference (unlabelled, so not in the legend).
        plt.plot([0, 1], [0, 1], color='navy', linestyle='--')
        plt.legend(loc='lower right')
        plt.title('ROC')
        plt.ylabel('True Pos Rate')
        plt.xlabel('False Pos Rate')
        plt.show()
    
    
    
    if __name__ == '__main__':
        # Synthetic binary-classification data: 1000 samples, 100 features,
        # no redundant features; fixed seed for reproducibility.
        x, y = datasets.make_classification(n_samples=1000, n_features=100,
                                            n_redundant=0, random_state=1)
        # Hold out 20% so the tuned model is evaluated on samples the grid
        # search never saw.
        train_X, test_X, train_Y, test_Y = train_test_split(x,
                                                            y,
                                                            test_size=0.2,
                                                            random_state=66)

        # Metric used to rank parameter combinations during CV and by
        # gsearch.score(); "f1", "roc_auc", ... are valid alternatives.
        scoring = "accuracy"
        parameters = {'n_estimators': range(50, 200, 25),
                      'max_depth': range(2, 10, 2)}
        # `iid` was deprecated in scikit-learn 0.22 and removed in 0.24, where
        # passing it raises TypeError. CV scores are now always the unweighted
        # mean over folds, which is what iid=False requested.
        gsearch = GridSearchCV(estimator=GradientBoostingClassifier(),
                               param_grid=parameters, scoring=scoring, cv=5)
        # Fit on the training split only. Fitting on all of (x, y) would leave
        # no unseen data for an honest final evaluation.
        gsearch.fit(train_X, train_Y)
        print("gsearch.best_params_")
        print(gsearch.best_params_)
        print("gsearch.best_score_")
        print(gsearch.best_score_)
        # With the default refit=True, the best parameter set is retrained on
        # all of train_X; score() applies the same `scoring` metric.
        print("held-out test score")
        print(gsearch.score(test_X, test_Y))
    

     效果（best_score_ 为 5 折交叉验证的平均 accuracy）:

    gsearch.best_params_
    {'max_depth': 4, 'n_estimators': 100}
    gsearch.best_score_
    0.868142228555714

  • 相关阅读:
    字串变换
    重建道路
    poj3278 Catch That Cow
    机器人搬重物
    [HNOI2004]打鼹鼠
    曼哈顿距离
    邮票面值设计
    poj1101 The Game
    解决了一个堆破坏问题
    模型资源从无到有一条龙式体验
  • 原文地址:https://www.cnblogs.com/bonelee/p/9154171.html
Copyright © 2011-2022 走看看