  • sklearn parameter tuning (validation curves: visualizing cross-validation scores for different parameter values)

     I. The basic method:

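    Approach:

    1. Pick sample points from an interval within the parameter's range (0, +∞), for example with np.logspace(-10, 0, 10) or np.logspace(-6, -1, 5).

    2. Call cross_val_score in a loop to compute the score for each value.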
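
To make step 1 concrete, here is a small sketch of what np.logspace returns: `num` values evenly spaced on a log scale between 10**start and 10**stop. The variable names are chosen only for this illustration.

```python
import numpy as np

# Five candidate values between 10**-6 and 10**-1, evenly spaced in log10.
candidates = np.logspace(-6, -1, 5)

# Taking log10 recovers the evenly spaced exponents: -6, -4.75, -3.5, -2.25, -1.
exponents = np.log10(candidates)

for value, exponent in zip(candidates, exponents):
    print(f"10**{exponent:.2f} = {value:.3e}")
```

Sampling on a log scale suits parameters such as C and gamma, whose useful values can span several orders of magnitude.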

    The example below computes the model's accuracy for different values of the SVM penalty parameter C.

    import matplotlib.pyplot as plt  
    from sklearn.model_selection import cross_val_score  
    import numpy as np  
    from sklearn import datasets, svm  
    digits = datasets.load_digits()  
    x = digits.data  
    y = digits.target  
    vsc = svm.SVC(kernel='linear')  
    
    if __name__=='__main__':  
        c_S = np.logspace(-10, 0, 10)  # 10 log-spaced values from 1e-10 to 1
        # print ("length", len(c_S))
        scores = list()
        scores_std = list()
        for c in c_S:
            vsc.C = c
            # n_jobs=4 runs the folds as 4 parallel jobs; cv is left at its default
            # (3-fold in older scikit-learn releases, 5-fold since version 0.22)
            this_scores = cross_val_score(vsc, x, y, n_jobs=4)
            scores.append(np.mean(this_scores))
            scores_std.append(np.std(this_scores))
        plt.figure(1, figsize=(4, 3))  # create the figure
        plt.clf()
        plt.semilogx(c_S, scores)  # mean score, log-scaled x axis
        # dashed lines: mean score plus/minus one standard deviation
        plt.semilogx(c_S, np.array(scores)+np.array(scores_std), 'b--')
        plt.semilogx(c_S, np.array(scores)-np.array(scores_std), 'b--')
        locs, labels = plt.yticks()
        plt.yticks(locs, ["%g" % v for v in locs])  # compact tick labels
        plt.ylabel('CV score')
        plt.xlabel('parameter C')
        plt.ylim(0, 1.1)  # y-axis range
        plt.show()

    Result:

    II. The advanced method (validation_curve)

    Approach:

    Use validation_curve directly to get the model's training score and test score for every cross-validation fold at each parameter value.
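
Before plotting, it helps to see the shape of what validation_curve returns. The following is a minimal sketch, not the article's exact run: the 300-sample subset, the 3 candidate values, and cv=3 are assumptions chosen only to keep it fast.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.model_selection import validation_curve
from sklearn.svm import SVC

# Assumption for speed: use only the first 300 samples of the digits data.
X, y = load_digits(return_X_y=True)
X_small, y_small = X[:300], y[:300]

gammas = np.logspace(-6, -1, 3)  # 3 candidate values for gamma
train_scores, test_scores = validation_curve(
    SVC(), X_small, y_small,
    param_name="gamma", param_range=gammas,
    cv=3, scoring="accuracy",
)

# One row per parameter value, one column per cross-validation fold.
print(train_scores.shape, test_scores.shape)

# Averaging over axis=1 gives one mean score per parameter value.
print(test_scores.mean(axis=1))
```

Each returned array has one row per parameter value and one column per fold, which is why the plotting code takes the mean and standard deviation along axis=1.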

    
    

    from sklearn import svm
    from sklearn.model_selection import validation_curve
    from sklearn.datasets import load_digits
    import numpy as np
    import matplotlib.pyplot as plt
    digits = load_digits()
    X = digits.data
    y = digits.target
    param_range = np.logspace(-6, -1, 5)  # 5 log-spaced gamma values from 1e-6 to 1e-1
    vsc = svm.SVC()
    # Both returned arrays have shape (len(param_range), number of folds)
    train_score, test_score = validation_curve(vsc, X, y, param_name='gamma', param_range=param_range, cv=10, scoring="accuracy", n_jobs=1)
    # Mean and standard deviation across the folds, per gamma value
    train_score_mean = np.mean(train_score, axis=1)
    train_score_std = np.std(train_score, axis=1)
    test_score_mean = np.mean(test_score, axis=1)
    test_score_std = np.std(test_score, axis=1)
    plt.title("validation curve with SVM")
    plt.xlabel(r"$\gamma$")
    plt.ylabel("Score")
    plt.ylim(0.0, 1.1)
    lw = 2
    plt.semilogx(param_range, train_score_mean, label="training score", color="darkorange", lw=lw)
    # Shade one standard deviation around each mean curve, in the curve's own color
    plt.fill_between(param_range, train_score_mean-train_score_std, train_score_mean+train_score_std, alpha=0.2, color="darkorange", lw=lw)
    plt.semilogx(param_range, test_score_mean, label="test score", color="navy", lw=lw)
    plt.fill_between(param_range, test_score_mean-test_score_std, test_score_mean+test_score_std, alpha=0.2, color="navy", lw=lw)
    plt.legend(loc="best")
    plt.show()

     

    Result:
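
One way to read a concrete value off such a curve is to take the parameter whose mean test score is highest. The sketch below assumes score arrays shaped like the output of validation_curve; `best_param` is a hypothetical helper, not part of scikit-learn, and the demo scores are made-up numbers for illustration, not results from the digits run.

```python
import numpy as np

def best_param(param_range, test_scores):
    """Return the parameter value with the highest mean score across folds,
    together with that mean and its standard deviation."""
    test_scores = np.asarray(test_scores)
    means = test_scores.mean(axis=1)  # one mean per parameter value
    stds = test_scores.std(axis=1)    # spread across folds
    best = int(np.argmax(means))
    return param_range[best], means[best], stds[best]

# Made-up scores with shape (3 parameter values, 2 folds), for illustration only.
demo_range = np.array([1e-6, 1e-3, 1e-1])
demo_scores = np.array([[0.10, 0.20],
                        [0.90, 0.98],
                        [0.50, 0.40]])

value, mean, std = best_param(demo_range, demo_scores)
print(value, mean, std)
```

With the arrays from the article's code, the call would be best_param(param_range, test_score). A large gap between the training and test curves at the chosen value suggests overfitting, so it is worth reading the two curves together rather than relying on the maximum alone.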

  • Original article: https://www.cnblogs.com/andylhc/p/10431201.html