zoukankan      html  css  js  c++  java
  • 学习进度笔记

    学习进度笔记09

    TensorFlow K近邻算法

    import numpy as np  

    import tensorflow as tf  

    from tensorflow.examples.tutorials.mnist import input_data  

    import os  

    os.environ["CUDA_VISIBLE_DEVICES"]="0"  

    mnist =input_data.read_data_sets("/home/yxcx/tf_data/MNIST_data",one_hot=True)  

    Xtr,Ytr=mnist.train.next_batch(5000)  

    Xte,Yte=mnist.test.next_batch(200)  

    #tf Graph Input  

    xtr=tf.placeholder("float",[None,784])  

    xte=tf.placeholder("float",[784])  

    distance =tf.reduce_sum(tf.abs(tf.add(xtr,tf.negative(xte))),reduction_indices=1)  

    pred=tf.argmin(distance,0)  

      

    accuracy=0  

    init=tf.global_variables_initializer()  

    #Start training  

    with tf.Session() as sess:  

        sess.run(init)  

        for i in range(len(Xte)):  

            #Get nearest nerighbor  

            nn_index=sess.run(pred,feed_dict={xtr:Xtr,xte:Xte[i,:]})  

            print("Test",i ,"Prediction:",np.argmax(Ytr[nn_index]),"True Class:",np.argmax(Yte[i]))  

            if np.argmax(Ytr[nn_index])==np.argmax(Yte[i]):  

                accuracy+=1./len(Xte)  

        print("Done!")  

        print("accuacy:" ,accuracy)  

  • 相关阅读:
    运维ipvsadm配置负载均衡
    vue--存储
    vue--v-model表单控件绑定
    CSS--外发光与内阴影
    vue运行报错--preventDefault
    vue运行报错--dependency
    vue--移动端兼容问题
    vue--生命周期
    vue--vuex
    Vue--vux组件库
  • 原文地址:https://www.cnblogs.com/xueqiuxiang/p/14466976.html
Copyright © 2011-2022 走看看