zoukankan      html  css  js  c++  java
  • TensorFlow(八) TensorFlow图像识别(KNN)

    import tensorflow as tf
    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn import  datasets
    import random
    from PIL import Image
    
    from tensorflow.examples.tutorials.mnist import  input_data
    
    # A single session is reused for every sess.run call in this script.
    sess = tf.Session()
    # Labels are loaded one-hot encoded (one_hot=True).
    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
    # This example has 10 classes.
    train_size = 1000
    test_size = 102
    # Sample the k-NN reference points from the training split, without
    # replacement.
    rand_train_indices = np.random.choice(len(mnist.train.images), train_size, replace=False)
    # Sample the evaluation points from the held-out *test* split.  Drawing
    # them from the training split (as this script previously did) lets a test
    # point coincide with one of its own reference points, which inflates the
    # measured accuracy.
    rand_test_indices = np.random.choice(len(mnist.test.images), test_size, replace=False)

    x_vals_train = mnist.train.images[rand_train_indices]
    x_vals_test = mnist.test.images[rand_test_indices]
    y_vals_train = mnist.train.labels[rand_train_indices]
    y_vals_test = mnist.test.labels[rand_test_indices]
    
    # Number of neighbours that vote, and number of test rows fed per run.
    k = 4
    batch_size = 6
    # Flattened images (784 features) and one-hot labels (10 classes); the
    # leading dimension is left open so any number of rows can be fed.
    x_data_train = tf.placeholder(shape=[None, 784], dtype=tf.float32)
    x_data_test = tf.placeholder(shape=[None, 784], dtype=tf.float32)
    y_target_train = tf.placeholder(shape=[None, 10], dtype=tf.float32)
    y_target_test = tf.placeholder(shape=[None, 10], dtype=tf.float32)

    # L1 distance between every test row and every training row.
    # expand_dims turns the test batch (B, 784) into (B, 1, 784); broadcasting
    # against the training set (N, 784) gives (B, N, 784), and summing over the
    # feature axis leaves (B, N) -- e.g. (6, 1000) with the sizes used here.
    distance = tf.reduce_sum(
        tf.abs(tf.subtract(x_data_train, tf.expand_dims(x_data_test, 1))),
        axis=2,
    )

    # top_k returns the largest entries, so negate the distances to pick the
    # k smallest ones.  Both outputs have shape (B, k).
    top_k_xvals, top_k_indices = tf.nn.top_k(tf.negative(distance), k=k)
    # One-hot labels of the k nearest training rows: (B, k, 10).
    prediction_indices = tf.gather(y_target_train, top_k_indices)
    # Summing the one-hot rows over the neighbour axis gives per-class vote
    # counts: (B, 10).
    count_of_prediction = tf.reduce_sum(prediction_indices, axis=1)
    # Predicted class is the one with the most votes: (B,).
    prediction = tf.argmax(count_of_prediction, axis=1)
    
    # Number of batches needed to cover the whole test sample (the last batch
    # may be shorter).  Integer ceiling avoids depending on true division.
    num_loop = (len(x_vals_test) + batch_size - 1) // batch_size
    test_output = []
    actual_vals = []
    for i in range(num_loop):
        min_index = i * batch_size
        max_index = min((i + 1) * batch_size, len(x_vals_test))
        # Slice out the current batch of test images and labels.
        x_batch = x_vals_test[min_index:max_index]
        y_batch = y_vals_test[min_index:max_index]
        # The full training sample is fed on every run because each test row
        # is compared against all reference rows.
        predictions = sess.run(
            prediction,
            feed_dict={
                x_data_test: x_batch,
                x_data_train: x_vals_train,
                y_target_test: y_batch,
                y_target_train: y_vals_train,
            },
        )
        test_output.extend(predictions)
        # Convert the one-hot labels back to class indices for comparison.
        actual_vals.extend(np.argmax(y_batch, axis=1))

    # Accuracy: fraction of test rows whose predicted class equals the label.
    # A vectorised mean avoids accumulating rounding error from repeatedly
    # adding 1/test_size; an empty test sample is reported as 0.
    if test_output:
        accuracy = float(np.mean(np.asarray(test_output) == np.asarray(actual_vals)))
    else:
        accuracy = 0.0
    print("Accuarcy: " + str(accuracy))
    
    # Visualise the last evaluated batch: each image is titled with its true
    # label and the k-NN prediction.  Relies on x_batch, y_batch and
    # predictions still holding the values from the final loop iteration.
    actuals = np.argmax(y_batch, axis=1)
    # Three images per row; keep at least two rows so the layout matches the
    # original 2x3 grid for batches of up to six, and grow for larger batches
    # instead of failing on an out-of-range subplot index.
    n_cols = 3
    n_rows = max(2, (len(actuals) + n_cols - 1) // n_cols)
    for i, (image, actual, predicted) in enumerate(zip(x_batch, actuals, predictions)):
        plt.subplot(n_rows, n_cols, i + 1)
        # Each image is a flat 784-vector; reshape to 28x28 for display.
        plt.imshow(np.reshape(image, [28, 28]), cmap="Greys_r")
        plt.title('Actual: ' + str(actual) + ' Pred:' + str(predicted), fontsize=10)
        # Hide the tick axes so only the digit and its title are shown.
        frame = plt.gca()
        frame.get_xaxis().set_visible(False)
        frame.get_yaxis().set_visible(False)

    plt.show()

  • 相关阅读:
    vue 进行ajax请求,使用axios
    webpack 小记
    js闭包
    git小记
    sublime text3 插件安装
    mysql安装
    四十九、django单表操作,多表操作,一对多,多对多,分组查询,聚合查询,F、Q查询,自定义char字段,事务,only与defer ,choices字段
    四十八、模板层,标签,过滤器,继承和导入
    四十七、django路由匹配,分组,反向解析,路由分发,视图层
    四十六、django配置及增删改查
  • 原文地址:https://www.cnblogs.com/x0216u/p/9241759.html
Copyright © 2011-2022 走看看