zoukankan      html  css  js  c++  java
  • 基于sklearn K临近算法 最简单预测 花的种类

    因为注释已经很详细了,所以直接上代码:

     1 from sklearn.datasets import load_iris
     2 from sklearn.model_selection import train_test_split
     3 #k临近算法
     4 from sklearn.neighbors import KNeighborsClassifier
     5 import numpy as np
     6 import pandas as pd
     7 def get数据():
     8     iris_dataset=load_iris()
     9     print("keys:
    {}".format(iris_dataset.keys()))
    10     print("预测花的品种
    :{}".format(iris_dataset['target_names']))
    11     '''['data', 'target', 'target_names', 'DESCR', 'feature_names', 'filename']'''
    12 
    13     '''机器学习中的个体叫作样本(sample),其属性叫作特征(feature)。data数组的形状(shape)是样本数乘以特征数。这是scikit-learn中的约定,你的数据形状应始终遵循这个约定'''
    14 
    15     ''':['setosa' 'versicolor' 'virginica']'''
    16     print(iris_dataset['target'])
    17 
    18 if __name__ =='__main__':
    19     '''数据'''
    20     iris_dataset=load_iris()
    21     X_train,X_test,y_train,y_test=train_test_split(
    22         iris_dataset['data'],iris_dataset['target'],random_state=0)
    23     print("X_train shapr:",X_train.shape)
    24 
    25 
    26     #运用pandas将Numpy数组转成pandas DataFrame
    27     #iris_dataframe=pd.DataFrame(X_train,columns=iris_dataset.feature_names)
    28     #绘制散点图
    29     #grr=pd.scatter_matrix(iris_dataframe,c=y_train,figsize=(15,15),marker='o',hist_kwds={'bins':20},s=60,alpha=8,cmap=mglearn.cm3)
    30 
    31     '''K临近算法 ,其中只包括训练集 '''
    32     knn=KNeighborsClassifier(n_neighbors=1)
    33     '''构建模型'''
    34     knn.fit(X_train,y_train)
    35     '''做出预测  1,4'''
    36     X_new=np.array([[5,2.9,1,0.2]])
    37     prediction=knn.predict(X_new)
    38     print("预测结果:
    {}".format(prediction))
    39     print("预测名字:
    {}".format(iris_dataset['target_names'][prediction]))
    40 
    41     '''计算精度(评估模型 对比)'''
    42     y_pred=knn.predict(X_test)
    43     index=np.mean(y_pred==y_test)
    44     print("相似度:{:.2f}".format(index))
    45     '''knn对象计算'''
    46     indexa=knn.score(X_test,y_test)
    47     print(indexa)
  • 相关阅读:
    [root@192 ~]# ls /etc/sysconfig/network-scripts
    解决unknown import path "golang.org/x/sys/unix": unrecognized import path "golang.org/x/sys"
    Centos 修改IP地址、网关、DNS
    Centos7 下安装golang
    yum国内镜像配置
    grep -R --include=*.log warning /var/log
    第五章 单例模式(待续)
    第四章 工厂模式(待续)
    第三章 装饰者模式(待续)
    第二章 观察者模式(待续)
  • 原文地址:https://www.cnblogs.com/smartisn/p/12556195.html
Copyright © 2011-2022 走看看