zoukankan      html  css  js  c++  java
  • tensorflow学习3---mnist

     1 import tensorflow as tf 
     2 from tensorflow.examples.tutorials.mnist import input_data
     3 
     4 '''数据下载'''
     5 mnist=input_data.read_data_sets('Mnist_data',one_hot=True)
     6 #one_hot标签
     7       
     8 '''生成层 函数'''
     9 def add_layer(input,in_size,out_size,n_layer='layer',activation_function=None):
    10     layer_name='layer %s' % n_layer
    11     '''补充知识'''
    12     #tf.name_scope:Wrapper for Graph.name_scope() using the default graph.
    13     #scope名字的作用域
    14     #sprase:A string (not ending with '/') will create a new name scope, in which name is appended to the prefix of all operations created in the context. 
    15     #If name has been used before, it will be made unique by calling self.unique_name(name).
    16     with tf.name_scope('weights'):
    17         Weights=tf.Variable(tf.random_normal([in_size,out_size]),name='w')
    18         tf.summary.histogram(layer_name+'/wights',Weights)
    19         #tf.summary.histogram:output summary with histogram直方图
    20         #tf,random_normal正太分布
    21     with tf.name_scope('biases'):
    22         biases=tf.Variable(tf.zeros([1,out_size])+0.1)
    23         tf.summary.histogram(layer_name+'/biases',biases)
    24         #tf.summary.histogram:k
    25     with tf.name_scope('Wx_plus_b'):
    26         Wx_plus_b=tf.matmul(input,Weights)+biases
    27     if activation_function==None:
    28         outputs=Wx_plus_b
    29     else:
    30         outputs=activation_function(Wx_plus_b)
    31     tf.summary.histogram(layer_name+'/output',outputs)
    32     return outputs
    33 '''准确率'''
    34 def compute_accuracy(v_xs,v_ys):
    35     global prediction
    36     y_pre=sess.run(prediction,feed_dict={xs:v_xs})#<
    37     #tf.equal()对比预测值的索引和实际label的索引是否一样,一样返回True,否则返回false
    38     correct_prediction=tf.equal(tf.argmax(y_pre,1),tf.argmax(v_ys,1))
    39     #correct_prediction-->[ True False  True ...,  True  True  True]
    40     '''补充知识-tf.argmax'''
    41     #tf.argmax:Returns the index with the largest value across dimensions of a tensor.
    42     #tf.argmax()----->
    43     accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
    44     #正确cast为1,错误cast为0
    45     '''补充知识 tf.cast'''
    46     #tf.cast:   Casts a tensor to a new type.
    47     ## tensor `a` is [1.8, 2.2], dtype=tf.float
    48     #tf.cast(a, tf.int32) ==> [1, 2]  # dtype=tf.int32
    49     result=sess.run(accuracy,feed_dict={xs:v_xs,ys:v_ys})
    50     #print(sess.run(correct_prediction,feed_dict={xs:v_xs,ys:v_ys}))
    51     #ckc=tf.cast(correct_prediction,tf.float32)
    52     #print(sess.run(ckc,feed_dict={xs:v_xs,ys:v_ys}))
    53     return result
    54 
    55 
    56 '''占位符'''
    57 xs=tf.placeholder(tf.float32,[None,784])
    58 ys=tf.placeholder(tf.float32,[None,10])
    59 
    60 '''添加层'''
    61 
    62 prediction=add_layer(xs,784,10,activation_function=tf.nn.softmax)
    63 #sotmax激活函数,用于分类函数
    64 
    65 '''计算'''
    66 #交叉熵cross_entropy损失函数,参数分别为实际的预测值和实际的label值y,re
    67 '''补充知识'''
    68 #reduce_mean()
    69 # 'x' is [[1., 1. ]]
    70 #         [2., 2.]]
    71 #tf.reduce_mean(x) ==> 1.5
    72 #tf.reduce_mean(x, 0) ==> [1.5, 1.5]
    73 #tf.reduce_mean(x, 1) ==> [1.,  2.]
    74 cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys*tf.log(prediction),reduction_indices=[1]))
    75 '''补充知识'''
    76 #reduce_sum
    77 # 'x' is [[1, 1, 1]]
    78 #         [1, 1, 1]]
    79 #tf.reduce_sum(x) ==> 6
    80 #tf.reduce_sum(x, 0) ==> [2, 2, 2]
    81 #tf.reduce_sum(x, 1) ==> [3, 3]
    82 #tf.reduce_sum(x, 1, keep_dims=True) ==> [[3], [3]]
    83 #tf.reduce_sum(x, [0, 1]) ==> 6
    84 
    85 train_step=tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
    86 
    87 '''Session_begin'''
    88 with tf.Session() as sess:
    89     sess.run(tf.global_variables_initializer())
    90     for i in range(1000):
    91         batch_xs,batch_ys=mnist.train.next_batch(100) #逐个batch去取数据
    92         sess.run(train_step,feed_dict={xs:batch_xs,ys:batch_ys})
    93         if(i%50==0):
    94             print(compute_accuracy(mnist.test.images,mnist.test.labels))
    95             
  • 相关阅读:
    如何在tomcat安装部署php项目
    十大建站开源程序
    虚拟主机、VPS、云主机以及独立服务器的关系
    heritrix启动问题修正
    网页布局:float与position的区别
    C#中利用委托实现多线程跨线程操作
    Java Service Wrapper配置详解
    Windows7部署WordPress傻瓜式教程(IIS7.5+MySQL+PHP+WordPress)
    关于favicon.ico的使用
    使用JAVA对字符串进行DES加密解密(修正问题)
  • 原文地址:https://www.cnblogs.com/ChenKe-cheng/p/8889229.html
Copyright © 2011-2022 走看看