zoukankan      html  css  js  c++  java
  • sklearn实现kNN

    对鸢尾花数据集进行分类并交叉验证

    from sklearn.datasets import load_iris
    from sklearn.model_selection import train_test_split
    from sklearn.preprocessing import StandardScaler
    from sklearn.neighbors import KNeighborsClassifier
    from sklearn.model_selection import GridSearchCV
    def kNN_iris_gscv():
        """
        用kNN对鸢尾花进行分类,添加网格搜索和交叉验证
        :return:
     """
        #1.获取数据
        iris=load_iris()
        #2.划分数据集
        x_train,x_test,y_train,y_test=train_test_split(iris.data,iris.target,random_state=1)
        #3.特征工程:标准化
        transfer=StandardScaler()
        x_train=transfer.fit_transform(x_train)
        x_test=transfer.transform(x_test) #使用训练集的平均值和标准差
        #4.模型训练
        estimator=KNeighborsClassifier()
        #加入网格搜索和交叉验证
        #参数准备
        param_dict={"n_neighbors":[1,3,5,7,9,11]}
        estimator=GridSearchCV(estimator,param_grid=param_dict,cv=10) #对estimator预估器进行10折交叉验证
        estimator.fit(x_train,y_train) #模型拟合
        #5.模型评估
        #方法1:比对真实值和预测值
        y_predict=estimator.predict(x_test)
        print(y_predict)
        print("直接比对真实值和预测值:
    ",y_predict==y_test)
        #方法2:直接计算准确率
        score=estimator.score(x_test,y_test)
        print("准确率为:",score)
    
        #最佳参数:best_params
        print("最佳参数:
    ",estimator.best_params_)
        #最佳结果:best_score_
        print("最佳结果:
    ", estimator.best_score_)
        #最佳估计器:best_estimator_
        print("最佳估计器:
    ", estimator.best_estimator_)
        #交叉验证结果:cv_results_
        print("交叉验证结果:
    ", estimator.cv_results_)
        return None
    
    
    if __name__=="__main__":
        kNN_iris_gscv()
  • 相关阅读:
    环境变量
    多重继承
    参数检查(@property)
    限制属性绑定(__slots__)
    实例属性和类属性
    2017-11-28 中文编程语言之Z语言初尝试: ZLOGO 4
    2017-10-23 在各种编程语言中使用中文命名
    2017-11-27 汉化了十数个编译器的前辈的心得体会
    五行
    阴阳
  • 原文地址:https://www.cnblogs.com/sclu/p/11759730.html
Copyright © 2011-2022 走看看