zoukankan      html  css  js  c++  java
  • K-NN回归算法

    # k-NN regression compared with ordinary linear regression on the iris data.
    #
    # Features: sepal length and sepal width (columns 0 and 1).
    # Target:   petal width (column 3).
    # All errors below are measured on the training data itself, so they only
    # illustrate how closely each model fits the data, not how it generalizes.
    import matplotlib.pyplot as plt
    import numpy as np
    from sklearn.datasets import load_iris
    from sklearn.linear_model import LinearRegression
    from sklearn.metrics import pairwise
    from sklearn.neighbors import KNeighborsRegressor

    iris = load_iris()
    iris_data = iris.data
    iris_target = iris.target
    # ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
    print(iris.feature_names)

    X = iris_data[:, 0:2]   # sepal length, sepal width
    y = iris_data[:, 3]     # petal width

    # We'll try to predict the petal width based on the sepal length and width.
    # We'll also fit a regular linear regression to see how well the k-NN
    # regression does in comparison.

    # Linear regression
    lr = LinearRegression()
    lr.fit(X, y)
    lr_pred = lr.predict(X)  # computed once, reused for the plot and the tables
    print("The MSE is: {:.2}".format(np.power(y - lr_pred, 2).mean()))

    # k-NN regression
    knnr = KNeighborsRegressor(n_neighbors=10)
    knnr.fit(X, y)
    knn_pred = knnr.predict(X)
    print("The MSE is: {:.2}".format(np.power(y - knn_pred, 2).mean()))

    # Only demonstrates how predict() is called for a single new sample;
    # the input must be 2-D, hence the reshape to (1, n_features).
    print(knnr.predict(np.array([3.0, 5.0]).reshape(1, -1)))

    # Let's look at what the k-NN regression does when we tell it to use the
    # closest 10 points for regression. Marker size encodes the predicted value.
    f, ax = plt.subplots(nrows=2, figsize=(7, 10))
    ax[0].set_title("Predictions")
    ax[0].scatter(X[:, 0], X[:, 1], s=lr_pred * 80, label='LRPredictions',
                  color='c', edgecolors='black')
    ax[1].scatter(X[:, 0], X[:, 1], s=knn_pred * 80, label='k-NNPredictions',
                  color='m', edgecolors='black')
    ax[0].legend()
    ax[1].legend()

    # Restrict the comparison to one class (k-NN does better than the linear
    # model here). Use a scalar class label so the mask is a plain element-wise
    # comparison instead of relying on broadcasting against np.where's output.
    setosa_label = np.flatnonzero(iris.target_names == 'setosa')[0]
    setosa_mask = iris.target == setosa_label
    print(y[setosa_mask][:20])
    print(knn_pred[setosa_mask][:20])
    print(lr_pred[setosa_mask][:20])

    # A single concrete point.
    # The k-NN regression is very simply calculated taking the average of the
    # k closest points to the point being tested.
    # Let's manually predict a single point:
    example_point = X[0]
    # Ground truth for this sample:
    #   X[0] -> array([5.1, 3.5])
    #   y[0] -> 0.2

    # Row 0 of the full distance matrix: distance from X[0] to every sample,
    # including itself.
    distances_to_example = pairwise.pairwise_distances(X)[0]
    # Sort once and reuse the indices of the 10 nearest samples for both the
    # feature rows and their known target values.
    ten_closest_idx = np.argsort(distances_to_example)[:10]
    ten_closest_points = X[ten_closest_idx]
    ten_closest_y = y[ten_closest_idx]
    print(ten_closest_y.mean())

    # We can see that this is very close to what was expected.

    # Figure.show() does not run a GUI event loop, so in a plain script the
    # window may never become visible; pyplot.show() blocks until it is closed.
    # It is called last so that all numeric results are printed first.
    plt.show()

    ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
    The MSE is: 0.15
    The MSE is: 0.069
    [ 0.2]
    [ 0.2  0.2  0.2  0.2  0.2  0.4  0.3  0.2  0.2  0.1  0.2  0.2  0.1  0.1  0.2
      0.4  0.4  0.3  0.3  0.3]
    [ 0.28  0.17  0.21  0.2   0.31  0.27  0.21  0.31  0.19  0.17  0.29  0.28
      0.17  0.19  0.26  0.27  0.27  0.28  0.27  0.31]
    [ 0.44636645  0.53893889  0.29846368  0.27338255  0.32612885  0.47403161
      0.13064785  0.42128532  0.22322028  0.49136065  0.56918808  0.27596658
      0.46627952  0.10298268  0.71709085  0.45411854  0.47403161  0.44636645
      0.73958795  0.30363175]
    0.28

  • 相关阅读:
    MyBatis+Oracle
    JAVA接口,json传递
    Oracle学习笔记(二)
    Oracle学习笔记(一)
    数据库事务四大特性之隔离性
    数据库事务四大特性(ACID)
    多表连接时条件放在 on 与 where 后面的区别
    tomcat request.getParamter() 乱码解决方案 Filter版本
    POI excel下载 中文名 浏览器兼容解决
    天马行空
  • 原文地址:https://www.cnblogs.com/qqhfeng/p/5338180.html
Copyright © 2011-2022 走看看