zoukankan      html  css  js  c++  java
  • 学习进度笔记

    学习进度笔记09

    TensorFlow K近邻算法

    import numpy as np  

    import tensorflow as tf  

    from tensorflow.examples.tutorials.mnist import input_data  

    import os  

    os.environ["CUDA_VISIBLE_DEVICES"]="0"  

    mnist =input_data.read_data_sets("/home/yxcx/tf_data/MNIST_data",one_hot=True)  

    Xtr,Ytr=mnist.train.next_batch(5000)  

    Xte,Yte=mnist.test.next_batch(200)  

    #tf Graph Input  

    xtr=tf.placeholder("float",[None,784])  

    xte=tf.placeholder("float",[784])  

    distance =tf.reduce_sum(tf.abs(tf.add(xtr,tf.negative(xte))),reduction_indices=1)  

    pred=tf.argmin(distance,0)  

      

    accuracy=0  

    init=tf.global_variables_initializer()  

    #Start training  

    with tf.Session() as sess:  

        sess.run(init)  

        for i in range(len(Xte)):  

            #Get nearest nerighbor  

            nn_index=sess.run(pred,feed_dict={xtr:Xtr,xte:Xte[i,:]})  

            print("Test",i ,"Prediction:",np.argmax(Ytr[nn_index]),"True Class:",np.argmax(Yte[i]))  

            if np.argmax(Ytr[nn_index])==np.argmax(Yte[i]):  

                accuracy+=1./len(Xte)  

        print("Done!")  

        print("accuacy:" ,accuracy)  

  • 相关阅读:
    php实现rpc简单的方法
    统计代码量
    laravel的速查表
    header的参数不能带下划线
    PHP简单实现单点登录功能示例
    phpStorm函数注释的设置
    数据结构基础
    laravel生命周期和核心思想
    深入理解php底层:php生命周期
    Jmeter:实例(性能测试目标)
  • 原文地址:https://www.cnblogs.com/xueqiuxiang/p/14466976.html
Copyright © 2011-2022 走看看