zoukankan      html  css  js  c++  java
  • TensorFlow K近邻算法

    TensorFlow K近邻算法

    实验目的

    1.掌握使用TensorFlow进行KNN操作

    2.掌握KNN 算法的原理

    实验原理

    knn的基本原理:

    KNN是通过计算不同特征值之间的距离进行分类。

    整体的思路是:如果一个样本在特征空间中的k个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别。K通常是不大于20的整数。KNN算法中,所选择的邻居都是已经正确分类的对象。该方法在分类决策上只依据最邻近的一个或者几个样本的类别来决定待分类样本所属的类别。

    KNN算法要解决的核心问题是K值选择,它会直接影响分类结果。如果选择较大的K值,就相当于用较大领域中的训练实例进行预测,其优点是可以减少学习的估计误差,但缺点是学习的近似误差会增大。如果选择较小的K值,就相当于用较小的领域中的训练实例进行预测,“学习”近似误差会减小,只有与输入实例较近或相似的训练实例才会对预测结果起作用,与此同时带来的问题是“学习”的估计误差会增大,换句话说,K值的减小就意味着整体模型变得复杂,容易发生过拟合;

    使用tensorflow进行KNN算法的整体过程是先设计计算图,然后运行会话,执行计算图的过程,整个过程的数据可见性比较差。以上精确度的计算以及真实标签和预测标签的比较结果其实使用numpy和python的变量。

    实验环境

    Windows10

    Python 3.6.0

    Pycharm

    TensorFlow

    实验内容

    使用TensorFlow进行K近邻算法的操作。

    实验步骤

    1.打开Pycharm,右键选择New=>Python File,创建名为tf_KNN的Python文件。

    2.打开tf_KNN.py文件,编写KNN代码。

    导入实验所需要的模块

    import numpy as np
    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data

    3.导入实验所需的数据

    mnist = input_data.read_data_sets("C:/Users/32016/Desktop/spark/深度学习算法部分/",one_hot = True)

    4.设置训练集与测试集的batch大小

    Xtr,Ytr=mnist.train.next_batch(5000)
    Xte,Yte=mnist.test.next_batch(200)

    5.构造计算图,使用占位符placeholder函数构造变量xtr,xte,代码如下:

    xtr=tf.placeholder("float",[None,784])
    xte=tf.placeholder("float",[784])

    6.求数据之间的距离,并取最小的值。

    distance = tf.reduce_sum(tf.abs(tf.add(xtr,tf.negative(xte))),reduction_indices=1)
    pred=tf.argmin(distance,0)

    7.初始化全部变量

    accuracy=0
    init=tf.global_variables_initializer()

    8.使用tf.Session()创建Session会话对象,会话封装了Tensorflow运行时的状态和控制。

    with tf.Session() as sess:
        sess.run(init)

    9.训练模型,并用测试数据预测其准备率。

    for i in range(len(Xte)):
        nn_index = sess.run(pred, feed_dict={xtr: Xtr, xte: Xte[i, :]})
        print("Test", i , "Prediction:", np.argmax(Ytr[nn_index]), "True Class:", np.argmax(Yte[i]))
        if np.argmax(Ytr[nn_index]) == np.argmax(Yte[i]):
            accuracy += 1./len(Xte)
        print("Done!")
        print("accuacy:" ,accuracy)

    10.完整代码:

    import numpy as np
    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    
    #导入实验所需的数据
    mnist = input_data.read_data_sets("C:/Users/32016/Desktop/spark/深度学习算法部分/",one_hot = True)
    #设置训练集与测试集的batch大小
    Xtr,Ytr=mnist.train.next_batch(5000)
    Xte,Yte=mnist.test.next_batch(200)
    #构造计算图,使用占位符placeholder函数构造变量xtr,xte
    xtr=tf.placeholder("float",[None,784])
    xte=tf.placeholder("float",[784])
    #求数据之间的距离,并取最小的值
    distance = tf.reduce_sum(tf.abs(tf.add(xtr,tf.negative(xte))),reduction_indices=1)
    pred=tf.argmin(distance,0)
    #初始化全部变量
    accuracy=0
    init=tf.global_variables_initializer()
    #使用tf.Session()创建Session会话对象,会话封装了Tensorflow运行时的状态和控制
    with tf.Session() as sess:
        sess.run(init)
        #训练模型,并用测试数据预测其准备率
        for i in range(len(Xte)):
            nn_index = sess.run(pred, feed_dict={xtr: Xtr, xte: Xte[i, :]})
            print("Test", i , "Prediction:", np.argmax(Ytr[nn_index]), "True Class:", np.argmax(Yte[i]))
            if np.argmax(Ytr[nn_index]) == np.argmax(Yte[i]):
                accuracy += 1./len(Xte)
            print("Done!")
            print("accuacy:" ,accuracy)

    11.运行结果为:

     

     

  • 相关阅读:
    福州KTV
    MSN登陆不上:微软谴责中国的“技术问题”
    DB2 存储过程开发最佳实践
    在DB2存储过程中返回一个数据集
    Host is not allowed to connect to this MySQL server 解决方案
    CentOS安装中文支持
    ImportError: libpq.so.5: cannot open shared object file: No such file or directory
    CentOS 终端显示中文异常解决办法
    pytestDemo
    python 获取当前运行的类名函数名
  • 原文地址:https://www.cnblogs.com/1gaoyu/p/12595643.html
Copyright © 2011-2022 走看看