zoukankan      html  css  js  c++  java
  • 学习进度笔记

    学习进度笔记09

    TensorFlow K近邻算法

    import numpy as np  

    import tensorflow as tf  

    from tensorflow.examples.tutorials.mnist import input_data  

    import os  

    os.environ["CUDA_VISIBLE_DEVICES"]="0"  

    mnist =input_data.read_data_sets("/home/yxcx/tf_data/MNIST_data",one_hot=True)  

    Xtr,Ytr=mnist.train.next_batch(5000)  

    Xte,Yte=mnist.test.next_batch(200)  

    #tf Graph Input  

    xtr=tf.placeholder("float",[None,784])  

    xte=tf.placeholder("float",[784])  

    distance =tf.reduce_sum(tf.abs(tf.add(xtr,tf.negative(xte))),reduction_indices=1)  

    pred=tf.argmin(distance,0)  

      

    accuracy=0  

    init=tf.global_variables_initializer()  

    #Start training  

    with tf.Session() as sess:  

        sess.run(init)  

        for i in range(len(Xte)):  

            #Get nearest nerighbor  

            nn_index=sess.run(pred,feed_dict={xtr:Xtr,xte:Xte[i,:]})  

            print("Test",i ,"Prediction:",np.argmax(Ytr[nn_index]),"True Class:",np.argmax(Yte[i]))  

            if np.argmax(Ytr[nn_index])==np.argmax(Yte[i]):  

                accuracy+=1./len(Xte)  

        print("Done!")  

        print("accuacy:" ,accuracy)  

  • 相关阅读:
    接口性能测试方案
    如何选择自动化测试框架
    一维和二维前缀和
    高精度 加减乘除
    归并排序 快速排序
    链表
    二分查找
    表达式求值
    c++ const问题小记
    虚继承总结
  • 原文地址:https://www.cnblogs.com/xueqiuxiang/p/14466976.html
Copyright © 2011-2022 走看看