  • 10. Finding the Best Hyperparameters

    import numpy as np
    import matplotlib
    import matplotlib.pyplot as plt
    from sklearn import datasets
    from sklearn.model_selection import train_test_split
    from sklearn.neighbors import KNeighborsClassifier
    from sklearn.metrics import accuracy_score

    1. Load the data

    digits = datasets.load_digits()
    X = digits.data
    y = digits.target

    2. Split the data into a training set and a test set

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=666)

    3. Manual search

    # Baseline: a single fit with a fixed k
    knn_clf = KNeighborsClassifier(n_neighbors=3)
    knn_clf.fit(X_train, y_train)
    y_predict = knn_clf.predict(X_test)
    print(accuracy_score(y_test, y_predict))

    # Search for the best k
    best_score = 0.0
    best_k = -1
    for k in range(1, 11):
        knn_clf = KNeighborsClassifier(n_neighbors=k)
        knn_clf.fit(X_train, y_train)
        y_predict = knn_clf.predict(X_test)
        score = accuracy_score(y_test, y_predict)
        if score > best_score:
            best_k = k
            best_score = score
    print("best_k:", best_k)
    print("best_score:", best_score)

    # Weight neighbors by distance, or not?
    best_method = ""
    best_score = 0.0
    best_k = -1
    for method in ["uniform", "distance"]:
        for k in range(1, 11):
            knn_clf = KNeighborsClassifier(n_neighbors=k, weights=method)
            knn_clf.fit(X_train, y_train)
            y_predict = knn_clf.predict(X_test)
            score = accuracy_score(y_test, y_predict)
            if score > best_score:
                best_k = k
                best_score = score
                best_method = method
    print("best_k:", best_k)
    print("best_score:", best_score)
    print("best_method:", best_method)

    # Still unexplored here: the Minkowski exponent p (see the sketch below)
    # A cleaner way to find the best hyperparameters: Grid Search
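
    The loops above leave the Minkowski exponent p unexplored. A minimal sketch of that third manual loop, mirroring the grid below where p is only varied together with weights="distance" (p parameterizes KNeighborsClassifier's Minkowski metric; p=1 is Manhattan, p=2 Euclidean):

    # Manual search over k and the Minkowski exponent p, with distance weighting
    best_score, best_k, best_p = 0.0, -1, -1
    for k in range(1, 11):
        for p in range(1, 6):
            knn_clf = KNeighborsClassifier(n_neighbors=k, weights="distance", p=p)
            knn_clf.fit(X_train, y_train)
            score = accuracy_score(y_test, knn_clf.predict(X_test))
            if score > best_score:
                best_k, best_p, best_score = k, p, score
    print("best_k:", best_k)
    print("best_p:", best_p)
    print("best_score:", best_score)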

    4. Hyperparameter grid configuration

    param_grid = [
        {
            "weights":["uniform"],
            "n_neighbors":[i for i in range(1,11)]
        },
        {
            "weights":["distance"],
            "n_neighbors":[i for i in range(1,11)],
            "p":[i for i in range(1,6)]
        }]
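
    Each dict expands to the Cartesian product of its values, so this grid defines 10 + 10 * 5 = 60 candidate settings. A quick sanity check (a sketch using sklearn's ParameterGrid helper, which enumerates the same combinations):

    from sklearn.model_selection import ParameterGrid
    print(len(ParameterGrid(param_grid)))  # 60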

    5. Instantiate the classifier

    knn_clf = KNeighborsClassifier()

    6. Build the grid search from the classifier and the hyperparameter grid

    from sklearn.model_selection import GridSearchCV
    grid_search = GridSearchCV(knn_clf, param_grid, n_jobs=-1, verbose=2)
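
    When cv and scoring are not given, GridSearchCV falls back to its default cross-validation scheme and the estimator's default scorer. A minimal sketch of pinning both explicitly (assuming 5-fold CV and accuracy as the selection metric):

    grid_search = GridSearchCV(knn_clf, param_grid, cv=5, scoring="accuracy",
                               n_jobs=-1, verbose=2)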

    7. Fit the grid search (a classifier under every parameter configuration) on the training set

    # Internally the training set is split again into training and validation folds
    # to pick the best parameter configuration.
    # Cross-validating every parameter combination is expensive, so this step can take a while.

    grid_search.fit(X_train, y_train)
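
    After fitting, the search object records what it found; best_params_ and best_score_ are standard attributes of a fitted GridSearchCV:

    print(grid_search.best_params_)  # the winning n_neighbors / weights / p combination
    print(grid_search.best_score_)   # mean cross-validated accuracy of that combination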

    8. Retrieve the classifier with the best parameter configuration via best_estimator_

    knn_clf = grid_search.best_estimator_

    9. Predict on the test set with the best classifier

    y_predict = knn_clf.predict(X_test)

    10. Print the accuracy

    print(accuracy_score(y_test, y_predict))
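
    Equivalently, a classifier's score method reports accuracy on the given data directly, skipping the explicit predict step:

    print(knn_clf.score(X_test, y_test))  # same value as the accuracy_score call above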
  • Original article: https://www.cnblogs.com/waterr/p/14039226.html