zoukankan      html  css  js  c++  java
  • TensorFlow K近邻算法(基于MNIST数据集)

    knn的基本原理:

    KNN是通过计算不同特征值之间的距离进行分类。

    整体的思路是:如果一个样本在特征空间中的k个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别。K通常是不大于20的整数。KNN算法中,所选择的邻居都是已经正确分类的对象。该方法在分类决策上只依据最邻近的一个或者几个样本的类别来决定待分类样本所属的类别。

    代码:

    import numpy as np
    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    import os
    os.environ["CUDA_VISIBLE_DEVICES"]="0"
    
    #导入数据
    MINIST_data=r'D:mnist' #数据集存放位置
    mnist=input_data.read_data_sets(MINIST_data,one_hot=True)
    #设置训练集与测试集的batch大小
    Xtr,Ytr=mnist.train.next_batch(5000)
    Xte,Yte=mnist.test.next_batch(200)
    #构造计算图,使用占位符placeholder函数构造变量xtr,xte,代码如下:
    xtr=tf.placeholder("float",[None,784])
    xte=tf.placeholder("float",[784])
    #求数据之间的距离,并取最小的值
    #tf.negative() 计算其负数  tf.abs() 求其绝对值 tf.argmin()返回矩阵横列或者纵列的最小值的坐标,取决于第二个参数 0为纵列 1为横列
    #曼哈顿距离
    distance =tf.reduce_sum(tf.abs(tf.add(xtr,tf.negative(xte))),reduction_indices=1)
    #欧式距离
    #distance = tf.sqrt(tf.reduce_sum(tf.square(tf.add(xtr, tf.negative(xte))), reduction_indices=1))
    pred=tf.argmin(distance,0)
    #初始化全部变量
    init=tf.global_variables_initializer()
    #使用tf.Session()创建Session会话对象,会话封装了Tensorflow运行时的状态和控制
    sess=tf.Session()
    sess.run(init)
    #训练模型,并用测试数据预测其准确率
    accuracy=0#计算准确率
    error=0 #错误个数
    for i in range(len(Xte)):
        nn_index=sess.run(pred,feed_dict={xtr:Xtr,xte:Xte[i,:]})
        print("Test", i, "Prediction:", np.argmax(Ytr[nn_index]), "True Class:", np.argmax(Yte[i]))
        if np.argmax(Ytr[nn_index]) == np.argmax(Yte[i]):
            accuracy += 1. / len(Xte)
        else:
            error+=1
    print("完成!")
    print("准确分类:",len(Xte)-error)
    print("错误分类:",error)
    print("准确率:",accuracy)

    结果:

  • 相关阅读:
    OCX控件的注册卸载,以及判断是否注册
    SimpleJdbcTemplate批量更新(BeanPropertySqlParameterSource)
    hibernateTemplate封装jdbc的一个简单思路
    Dao层查询
    ==与equals方法的区别(Java基础)
    中文乱码解决办法
    spring核心配置文件_ActiveMQ消息队列配置
    spring核心配置文件_Elasticsearch搜索配置
    spring核心配置文件_数据库连接信息
    spring核心配置文件_数据库连接信息_数据库信息
  • 原文地址:https://www.cnblogs.com/1061321925wu/p/12613363.html
Copyright © 2011-2022 走看看