  • KNN classification on load_wine with train_test_split and GridSearchCV

    Original post; please credit the source when reposting: https://www.cnblogs.com/agilestyle/p/12678281.html

    Prepare the data

    import numpy as np
    from sklearn.datasets import load_wine
    from sklearn.model_selection import GridSearchCV
    from sklearn.model_selection import train_test_split
    from sklearn.neighbors import KNeighborsClassifier
    
    samples = load_wine()
    
    samples
    
    # sklearn.utils.Bunch
    type(samples)
    
    # dict_keys(['data', 'target', 'target_names', 'DESCR', 'feature_names'])
    samples.keys()
    
    print(samples['DESCR'])
    
    # (178, 13)
    samples['data'].shape
    
    # (178,)
    samples['target'].shape
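Beyond the array shapes, it can help to check how the 178 samples are spread over the three classes. The following is a minimal sketch using `np.bincount`; the variable name `class_counts` is introduced here for illustration only.

```python
import numpy as np
from sklearn.datasets import load_wine

samples = load_wine()

# Count how many samples carry each integer label (0, 1, 2).
class_counts = np.bincount(samples['target'])

# Pair each class name with its sample count.
for name, count in zip(samples['target_names'], class_counts):
    print(name, count)
```

The classes are not perfectly balanced, which is worth keeping in mind when reading the accuracy figures later on.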

    Split the data: hold out 20% as the test set and use the rest as the training set

    X_train, X_test, y_train, y_test = train_test_split(samples['data'], samples['target'], test_size=0.2, random_state=2)
    
    # (142, 13)
    X_train.shape
    
    # (36, 13)
    X_test.shape
    
    # (142,)
    y_train.shape
    
    # (36,)
    y_test.shape
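The split above is purely random, so the class proportions in the test set may drift from those of the full dataset. As a variation that is not part of the original post, `train_test_split` accepts a `stratify` argument that keeps the proportions roughly equal in both subsets. The `_s` suffixed names below are illustrative.

```python
import numpy as np
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split

samples = load_wine()
X, y = samples['data'], samples['target']

# stratify=y asks train_test_split to preserve class proportions in both subsets.
X_train_s, X_test_s, y_train_s, y_test_s = train_test_split(
    X, y, test_size=0.2, random_state=2, stratify=y
)

# Fraction of each class in the training and test subsets.
train_ratio = np.bincount(y_train_s) / len(y_train_s)
test_ratio = np.bincount(y_test_s) / len(y_test_s)
print(train_ratio)
print(test_ratio)
```

The subset sizes stay at 142 and 36; only the way samples are assigned to them changes.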

    Build and train the model

    # KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
    #                     metric_params=None, n_jobs=None, n_neighbors=5, p=2,
    #                     weights='uniform')
    knn_clf = KNeighborsClassifier()
    knn_clf.fit(X_train, y_train)
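The default settings can also be read programmatically with `get_params()`, which is available on every scikit-learn estimator. A small sketch follows; the variable `params` is illustrative.

```python
from sklearn.neighbors import KNeighborsClassifier

knn_clf = KNeighborsClassifier()

# get_params() returns the constructor arguments as a plain dict.
params = knn_clf.get_params()

# With metric='minkowski', p=2 corresponds to Euclidean distance.
print(params['n_neighbors'], params['p'], params['weights'], params['metric'])
```

These are the same parameter names that appear as keys in a grid search later, which is why the grid can be written as an ordinary dictionary.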

    Evaluate the model

    score = knn_clf.score(X_test, y_test)
    
    # 0.75
    score
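For a classifier, `score` returns the mean accuracy, that is, the fraction of test samples whose predicted label matches the true label. The sketch below recomputes it by hand to make that explicit; `manual_accuracy` is a name introduced here.

```python
import numpy as np
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

samples = load_wine()
X_train, X_test, y_train, y_test = train_test_split(
    samples['data'], samples['target'], test_size=0.2, random_state=2
)

knn_clf = KNeighborsClassifier()
knn_clf.fit(X_train, y_train)

# Accuracy by hand: the share of predictions that equal the true labels.
y_pred = knn_clf.predict(X_test)
manual_accuracy = np.mean(y_pred == y_test)

# score() on a classifier computes the same quantity.
score = knn_clf.score(X_test, y_test)
print(manual_accuracy, score)
```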

    Grid search + cross-validation

    Note: A parameter grid is a dictionary that maps each parameter name to the list of values you would like to try.
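To see how quickly a grid grows, it can be expanded with `ParameterGrid`, which enumerates every combination of the listed values. The sketch below uses the grid from this post and assumes 10-fold cross-validation; `candidates`, `n_candidates` and `n_fits` are illustrative names.

```python
from sklearn.model_selection import ParameterGrid

param_grid = {
    'n_neighbors': [2, 3, 4, 5, 6, 7, 8],
    'p': [1, 2],
    'weights': ['uniform', 'distance']
}

# Every combination of the listed values becomes one candidate model.
candidates = list(ParameterGrid(param_grid))
n_candidates = len(candidates)

# With cv=10, each candidate is trained and validated on 10 folds.
n_fits = n_candidates * 10
print(n_candidates, n_fits)
```

With the default `refit=True`, `GridSearchCV` additionally fits the best candidate once more on the whole training set.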

    param_grid = {
        'n_neighbors': [2, 3, 4, 5, 6, 7, 8],
        'p': [1, 2],
        'weights': ['uniform', 'distance']
    }
    
    grid_search = GridSearchCV(knn_clf, param_grid=param_grid, cv=10)
    grid_search.fit(X_train, y_train)
    
    # 0.8333333333333334
    grid_search.score(X_test, y_test)
    
    # {'n_neighbors': 2, 'p': 1, 'weights': 'distance'}
    grid_search.best_params_
    
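    # A made-up sample: 13 random integers, used only to demonstrate predict().
    # It is not a realistic wine measurement, and the prediction varies from run to run.
    new_wine = np.random.randint(1, 100, (1, 13))
    # (1, 13)
    new_wine.shape
    
    pred_result = grid_search.predict(new_wine)
    # array([1]) in the original run
    pred_result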
    
    # array(['class_1'], dtype='<U7')
    samples['target_names'][pred_result]
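Besides `best_params_`, a fitted `GridSearchCV` exposes the cross-validated score of the winning candidate and the full table of results. The following sketch repeats the search so that it runs on its own; `results` and `best_idx` are names introduced here.

```python
from sklearn.datasets import load_wine
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.neighbors import KNeighborsClassifier

samples = load_wine()
X_train, X_test, y_train, y_test = train_test_split(
    samples['data'], samples['target'], test_size=0.2, random_state=2
)

param_grid = {
    'n_neighbors': [2, 3, 4, 5, 6, 7, 8],
    'p': [1, 2],
    'weights': ['uniform', 'distance']
}

grid_search = GridSearchCV(KNeighborsClassifier(), param_grid=param_grid, cv=10)
grid_search.fit(X_train, y_train)

# Mean cross-validated accuracy of the best candidate, measured on training data only.
print(grid_search.best_score_)

# cv_results_ holds one entry per candidate; best_index_ points at the winner.
results = grid_search.cv_results_
best_idx = grid_search.best_index_
print(results['params'][best_idx], results['mean_test_score'][best_idx])
```

`best_score_` is an average over validation folds drawn from the training set, so it will generally differ from `grid_search.score(X_test, y_test)`, which is measured on the held-out test set.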

    Reference

    https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html

    https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html
