zoukankan      html  css  js  c++  java
  • TFboy养成记 MNIST Classification (主要是如何计算accuracy)

     参考:莫烦。

    主要是运用的MLP。另外这里用到的是批训练:

    这个代码很简单,跟上次的基本没有什么区别。

    这里的lossfunction用到的是是交叉熵cross_entropy.可能网上很多形式跟这里的并不一样。

    这里一段时间会另开一个栏。专门去写一些机器学习上的一些理论知识。

    这里代码主要写一下如何计算accuracy:

    1 def getAccuracy(v_xs,v_ys):
    2     global y_pre
    3     y_v = sess.run(y_pre,feed_dict={x:v_xs})
    4     correct_prediction = tf.equal(tf.arg_max(y_v,1),tf.arg_max(v_ys,1))
    5     accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
    6     result = sess.run(accuracy,feed_dict={x:v_xs,y:v_ys})
    7     
    8     return result

    首先得到ground truth,与预测值,然后对着预测值得到tf,arg_max---->你得到的是以float tensor,tensor上的各个值是各个分类结果的可能性,而argmax函数就是求里面的最大值的下表也就是结果。

    注意这里每次得到的是一个batch的结果,也就是说以一个【9,1,2,、。。。。】的这种tensor,所以最后用tf.equal得到一个表示分类值与实际类标是否相同的Bool型tensor。最后把tensor映射到0,1,两个值上就可以了.

    可能会有人问为什么不用int表示而是用float32来表示呢?因为下面腰酸的是准确率,如果是int32,那么按tensorflow的整数除法运算是直接取整数部分不算小数点的。(这几个涉及到的函数在之前的博客)

    全部代码:

     1 # -*- coding: utf-8 -*-
     2 """
     3 Created on Sun Jun 18 15:31:11 2017
     4 
     5 @author: Jarvis
     6 """
     7 
     8 import tensorflow as tf
     9 import numpy as np
    10 from tensorflow.examples.tutorials.mnist import input_data
    11 
    12 def addlayer(inputs,insize,outsize,activate_func = None):
    13     W = tf.Variable(tf.random_normal([insize,outsize]),tf.float32)
    14     b = tf.Variable(tf.zeros([1,outsize]),tf.float32)
    15     W_plus_b = tf.matmul(inputs,W)+b
    16 
    17     if activate_func == None:
    18         return W_plus_b
    19     else:
    20         return activate_func(W_plus_b)
    21 def getAccuracy(v_xs,v_ys):
    22     global y_pre
    23     y_v = sess.run(y_pre,feed_dict={x:v_xs})
    24     correct_prediction = tf.equal(tf.arg_max(y_v,1),tf.arg_max(v_ys,1))
    25     
    26     accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
    27     result = sess.run(accuracy,feed_dict={x:v_xs,y:v_ys})
    28     
    29     return result
    30 mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
    31 
    32 x  = tf.placeholder(tf.float32,[None,784])
    33 y = tf.placeholder(tf.float32,[None,10])
    34 #h1 = addlayer(x,784,14*14,activate_func=tf.nn.softmax)
    35 #y_pre = addlayer(h1,14*14,10,activate_func=tf.nn.softmax)
    36 y_pre = addlayer(x,784,10,activate_func=tf.nn.softmax)
    37 
    38 cross_entropy = tf.reduce_mean(-tf.reduce_sum(y*tf.log(y_pre),reduction_indices=[1]))
    39 train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
    40 
    41 sess = tf.Session()
    42 sess.run(tf.global_variables_initializer())
    43 for i in range(10001):
    44     x_batch,y_batch = mnist.train.next_batch(100)
    45     sess.run(train_step,feed_dict={x:x_batch,y:y_batch})
    46     
    47     if i % 100 == 0:
    48         print (getAccuracy(mnist.test.images,mnist.test.labels))
    49     
    View Code
  • 相关阅读:
    Java8 Stream Function
    PLINQ (C#/.Net 4.5.1) vs Stream (JDK/Java 8) Performance
    罗素 尊重 《事实》
    小品 《研发的一天》
    Java8 λ表达式 stream group by max then Option then PlainObject
    这人好像一条狗啊。什么是共识?
    TOGAF TheOpenGroup引领开发厂商中立的开放技术标准和认证
    OpenMP vs. MPI
    BPMN2 online draw tools 在线作图工具
    DecisionCamp 2019, Decision Manager, AI, and the Future
  • 原文地址:https://www.cnblogs.com/silence-tommy/p/7045850.html
Copyright © 2011-2022 走看看