zoukankan      html  css  js  c++  java
  • sklearn获得某个参数的不同取值在训练集和测试集上的表现的曲线刻画

    from sklearn.svm import SVC
    from sklearn.datasets import make_classification
    import numpy as np
    
    X,y = make_classification()
    
    
    def plot_validation_curve(estimator,X,y,param_name="gamma",
                              param_range=np.logspace(-6,-1,5),cv=5,scoring="accuracy"):
        """
        描述:获得某个参数的不同取值在训练集和测试集上的表现
        """
        from sklearn.model_selection import validation_curve
        import matplotlib.pyplot as plt
        
        train_scores,test_scores = validation_curve(estimator=estimator, 
                                                    X=X, 
                                                    y=y, 
                                                    cv=cv,
                                                    scoring=scoring,
                                                    param_name=param_name,
                                                    param_range=param_range)
        
        train_scores_mean = np.mean(train_scores, axis=1)
        train_scores_std  = np.std(train_scores, axis=1)
        test_scores_mean  = np.mean(test_scores, axis=1)
        test_scores_std   = np.std(test_scores, axis=1)
        
        plt.title("Validation Curve")
        plt.xlabel("$gamma$")
        plt.ylabel("Score")
        plt.ylim(0.0, 1.1)
        
        plt.semilogx(param_range,train_scores_mean,label="Training score",color="darkorange", lw=2)
        plt.fill_between(param_range,
                         train_scores_mean-train_scores_std,
                         train_scores_mean+train_scores_std,
                         alpha=0.2,
                         color="darkorange", 
                         lw=2)
        
        plt.semilogx(param_range, test_scores_mean, label="Cross-validation score",color="navy", lw=2)    
        plt.fill_between(param_range, 
                         test_scores_mean - test_scores_std,
                         test_scores_mean + test_scores_std, 
                         alpha=0.2,
                         color="navy", 
                         lw=2)
        
        plt.legend(loc="best")
        plt.show()
        
    
        
    plot_validation_curve(estimator=SVC(),
                          X=X,y=y,
                          param_name="gamma",
                          param_range=np.logspace(-6,-1,5),cv=5,scoring="accuracy")    
        
  • 相关阅读:
    自己做一个无敌的文件粉碎机
    编程王道,唯“慢”不破
    在Flex4中嵌入字体
    java函数参数默认值
    Adobe Air移动开发本人体会
    安装VS2013,可是电脑C盘没空间了,今天早上整理了下
    SilverFoxServer出炉!!
    C#中Abstract和Virtual
    解决insert语句插入时,需要写列值的问题
    SQL 标量函数-----日期函数 day() 、month()、year()
  • 原文地址:https://www.cnblogs.com/wzdLY/p/9886270.html
Copyright © 2011-2022 走看看