zoukankan      html  css  js  c++  java
  • 使用交叉验证对鸢尾花分类模型进行调参(超参数)

    如何选择超参数:

    交叉验证:

    如图,

    1. 大训练集分块,使用不同的分块方法分成N对小训练集验证集
    2. 使用小训练集进行训练,使用验证集进行验证,得到准确率,求N个验证集上的平均正确率
    3. 使用平均正确率最高的超参数,对整个大训练集进行训练,训练出参数。
    4. 训练集上训练。
     
     
     
    十折交叉验证

    网格搜索

    诸如你有多个可调节的超参数,那么选择超参数的方法通常是网格搜索,即固定一个参、变化其他参,像网格一样去搜索。


     
     
     
     
    # 人工智能数据源下载地址:https://video.mugglecode.com/data_ai.zip,下载压缩包后解压即可(数据源与上节课相同)
    # -*- coding: utf-8 -*-
    
    """
        任务:鸢尾花识别
    """
    import pandas as pd
    from sklearn.model_selection import train_test_split, GridSearchCV
    from sklearn.neighbors import KNeighborsClassifier
    from sklearn.linear_model import LogisticRegression
    from sklearn.svm import SVC
    
    
    DATA_FILE = './data_ai/Iris.csv'
    
    SPECIES_LABEL_DICT = {
        'Iris-setosa':      0,  # 山鸢尾
        'Iris-versicolor':  1,  # 变色鸢尾
        'Iris-virginica':   2   # 维吉尼亚鸢尾
    }
    
    # 使用的特征列
    FEAT_COLS = ['SepalLengthCm', 'SepalWidthCm', 'PetalLengthCm', 'PetalWidthCm']
    
    
    def main():
        """
            主函数
        """
        # 读取数据集
        iris_data = pd.read_csv(DATA_FILE, index_col='Id')
        iris_data['Label'] = iris_data['Species'].map(SPECIES_LABEL_DICT)
    
        # 获取数据集特征
        X = iris_data[FEAT_COLS].values
    
        # 获取数据标签
        y = iris_data['Label'].values
    
        # 划分数据集
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=1/3, random_state=10)
    
        model_dict = {'kNN':
                          (
                              KNeighborsClassifier(),
                              {'n_neighbors': [5, 15, 25], 'p': [1, 2]}
                           ),
                      'Logistic Regression':
                          (
                              LogisticRegression(),
                              {'C': [1e-2, 1, 1e2]}
                          ),
                      'SVM':
                          (
                              SVC(),
                              {'C': [1e-2, 1, 1e2]}
                          )
                      }   # 名称+元组
    
        for model_name, (model, model_params) in model_dict.items():
            # 训练模型
            clf = GridSearchCV(estimator=model, param_grid=model_params, cv=5) #模型、参数、折数
            clf.fit(X_train, y_train)   #训练
            best_model = clf.best_estimator_   #最佳模型的对象
    
            # 验证
            acc = best_model.score(X_test, y_test)
            print('{}模型的预测准确率:{:.2f}%'.format(model_name, acc * 100))
            print('{}模型的最优参数:{}'.format(model_name, clf.best_params_))       #最好的模型名称和参数
    
    
    if __name__ == '__main__':
        main()

    运行结果:

    kNN模型的预测准确率:96.00%
    kNN模型的最优参数:{'n_neighbors': 15, 'p': 2}
    Logistic Regression模型的预测准确率:96.00%
    Logistic Regression模型的最优参数:{'C': 100.0}
    SVM模型的预测准确率:98.00%
    SVM模型的最优参数:{'C': 1}
    

    练习

    练习:使用交叉验证对水果分类模型进行调参

    • 题目描述:为模型选择最优的参数并进行水果类型识别,模型包括kNN,逻辑回归及SVM。对应的超参数为:

    • kNN中的近邻个数n_neighbors及闵式距离的p值

    • 逻辑回归的正则项系数C值

    • SVM的正则项系数C值

    • 题目要求:

    • 使用3折交叉验证对模型进行调参

    • 使用scikit-learn提供的方法为模型调参

    • 数据文件:

    • 数据源下载地址:https://video.mugglecode.com/fruit_data.csv(数据源与上节课相同)

    • fruit_data.csv,包含了59个水果的的数据样本。

    • 共5列数据

    • fruit_name:水果类别

    • mass: 水果质量

    • 水果的宽度

    • height: 水果的高度

    • color_score: 水果的颜色数值,范围0-1。

    • 0.85 - 1.00:红色

    • 0.75 - 0.85: 橙色

    • 0.65 - 0.75: 黄色

    • 0.45 - 0.65: 绿色


       
      image

    可能的代码

    import pandas as pd
    from sklearn.model_selection import GridSearchCV, train_test_split
    from sklearn.neighbors import KNeighborsClassifier
    from sklearn.linear_model import LogisticRegression
    from sklearn.svm import SVC
    
    #读取数据
    data = pd.read_csv('./data_ai/fruit_data.csv')
    
    #数据处理
    fruit_dict = {
        'apple':    0,
        'lemon':    1,
        'mandarin': 2,
        'orange':   3
    }
    
    data['label'] = data['fruit_name'].map(fruit_dict)
    
    feat_cols = ['mass','width','height','color_score']
    
    #数据提取
    X = data[feat_cols].values
    y = data['label'].values
    
    X_train, X_test, y_train, y_test = train_test_split(X,y,test_size=1/5, random_state= 3)
    
    model_dict = {
        'KNN': ( KNeighborsClassifier(), {'n_neighbors': [5,15,25], 'p' : [1,2]} ),
        'Logestic Regression': (LogisticRegression(), {'C':[1e02, 1, 1e2] }),
        'SVM': (SVC(), {'C':[1e02, 1, 1e2]})
    }
    
    for model_name, (model, model_para) in model_dict.items():
        #训练
        clf = GridSearchCV(estimator=model, param_grid=model_para, cv=5)  # 模型、参数、折数
        clf.fit(X_train,y_train)
        best_model = clf.best_estimator_
    
        #验证
        acc = best_model.score(X_test, y_test)
        print(f'{model_name}中选择{clf.best_params_}为参数的预测准确率最好,准确率可达{acc*100}%')

    运行结果:

    KNN中选择{'n_neighbors': 5, 'p': 1}为参数的预测准确率最好,准确率可达66.66666666666666%
    Logestic Regression中选择{'C': 100.0}为参数的预测准确率最好,准确率可达91.66666666666666%
    SVM中选择{'C': 100.0}为参数的预测准确率最好,准确率可达50.0%



    作者:夏威夷的芒果
    链接:https://www.jianshu.com/p/790ac622dc18
    來源:简书

  • 相关阅读:
    docker的网络服务
    想真正了解JAVA设计模式看着一篇就够了。 详解+代码实例
    再问你Java内存模型的时候别再给我讲堆栈方法区
    ssh爆破脚本
    ecshop3.0.0注入
    zabbix 安装配置以及漏洞检测脚本
    代理爬取
    selenium2使用记录
    初级AD域渗透系列
    用ftplib爆破FTP口令
  • 原文地址:https://www.cnblogs.com/caiyishuai/p/13270902.html
Copyright © 2011-2022 走看看