zoukankan      html  css  js  c++  java
  • 参数优化-验证曲线

    通过验证一个学习器在训练集和测试集上的表现,来确定模型是否合适,参数是否合适。

    如果训练集和测试集得分都很低,说明学习器不合适。

    如果训练集得分高,测试集得分低,模型过拟合,训练集得分低,测试集得分高,不太可能。

    示例代码

    import matplotlib.pyplot as plt
    import numpy as np
    from sklearn.datasets import load_digits
    from sklearn.model_selection import validation_curve
    from sklearn.svm import SVC
    
    # 加载数据
    digits = load_digits()
    X, y = digits.data, digits.target
    
    # 验证曲线
    param_range = np.logspace(-6, -1, 5)
    train_scores, test_scores = validation_curve(
        SVC(), X, y, param_name="gamma", param_range=param_range,
        cv=10, scoring="accuracy", n_jobs=1)
    train_scores_mean = np.mean(train_scores, axis=1)
    train_scores_std = np.std(train_scores, axis=1)
    test_scores_mean = np.mean(test_scores, axis=1)
    test_scores_std = np.std(test_scores, axis=1)
    
    plt.title("SVM VC")
    plt.xlabel("$gamma$")
    plt.ylabel("Score")
    plt.ylim(0.0, 1.1)
    
    # 训练数据
    plt.semilogx(param_range, train_scores_mean, label="train score", color="r")
    plt.fill_between(param_range, train_scores_mean - train_scores_std,
                     train_scores_mean + train_scores_std, alpha=0.2, color="r")
    
    # 测试数据
    plt.semilogx(param_range, test_scores_mean, label="test score",color="g")
    plt.fill_between(param_range, test_scores_mean - test_scores_std,
                     test_scores_mean + test_scores_std, alpha=0.2, color="g")
    plt.legend(loc="best")
    plt.show()

    输出

    参数gamma的调节

    很小时,训练集和测试集得分都低,欠拟合

    增大时,训练集和测试集得分有个很好地值

    过大时,训练集得分高,测试集得分低,过拟合。

  • 相关阅读:
    监听器、过滤器
    最详细的Log4j使用教程
    Tomcat version 7.0 only supports J2EE 1.2, 1.3, 1.4, and Java EE 5 and 6 Web modules
    Unsupported major.minor version 52.0
    jdk安装
    数据库建表
    SpringMVC学习系列-后记 解决GET请求时中文乱码的问题
    面向对象中的常用魔术方法
    面向对象中的修饰关键词
    面向对象三大特性之二--【继承】
  • 原文地址:https://www.cnblogs.com/yanshw/p/10688553.html
Copyright © 2011-2022 走看看