zoukankan      html  css  js  c++  java
  • Tensorflow入门实战-mnist手写体识别

     1 '''
     2 tensorflow 教程
     3 mnist样例
     4 '''
     5 import tensorflow as tf 
     6 from tensorflow.examples.tutorials.mnist import input_data
     7 
     8 #参数设置
     9 INPUT_NODE=784
    10 OUTPUT_NODE=10
    11 LAYER1_NODE=500
    12 BATCH_SIZE=100
    13 LEARNING_RATE_BASE=0.8
    14 LEARNING_RATE_DECAY=0.99
    15 REGULARIZATION_RATE=0.0001
    16 TRAINING_STEPS=10000
    17 MOVEING_AVEARGE_DECAY=0.99
    18 
    19 
    20 def inference(input_tensor,avg_class,weights1,biases1,weights2,biases2):
    21     '''
    22     定义前向计算的过程:
    23     avg_class是滑动平均函数,使权重平滑过渡,保留历史数据,
    24     为None时,表示普通的参数更新过程
    25     '''
    26     if avg_class==None:
    27         layer1=tf.nn.relu(tf.matmul(input_tensor,weights1)+biases1)
    28         return tf.matmul(layer1,weights2)+biases2
    29     else:
    30         layer1=tf.nn.relu(tf.matmul(input_tensor,avg_class.average(weights1)+avg_class.average(biases1)))
    31         return tf.matmul(layer1,avg_class.average(weights2)+avg_class.average(biases2))
    32 
    33 def train(mnist):
    34     #设置输入变量 placerholder表示占位,开启会话训练的时候需要传入数据
    35     x=tf.placeholder(tf.float32,[None,INPUT_NODE],name='x-input')
    36     y_=tf.placeholder(tf.float32,[None,OUTPUT_NODE],name='y-input')
    37 
    38     #设置权重变量,variable表示训练时需要自动更新
    39     weights1=tf.Variable(tf.random_normal([INPUT_NODE,LAYER1_NODE],stddev=0.1))
    40     biases1=tf.Variable(tf.constant(0.1,shape=[LAYER1_NODE]))
    41     weights2=tf.Variable(tf.random_normal([LAYER1_NODE,OUTPUT_NODE],stddev=0.1))
    42     biases2=tf.Variable(tf.constant(0.1,shape=[OUTPUT_NODE]))
    43     
    44     #y=inference(x,None,weights1,biases1,weights2,biases2)
    45     
    46     global_step=tf.Variable(0,trainable=False)#不可更新参数
    47     variable_averages=tf.train.ExponentialMovingAverage(MOVEING_AVEARGE_DECAY,global_step)#min(decay,(1+step)/(10+step)) 后面的变量会越来越大,表示参数的更新越来越稳定,大都依赖于历史数据
    48     variable_averages_op=variable_averages.apply(tf.trainable_variables())
    49     average_y=inference(x,variable_averages,weights1,biases1,weights2,biases2)
    50 
    51     cross_entropy=tf.nn.sparse_softmax_cross_entropy_with_logits(logits=average_y,labels=tf.argmax(y_,1))#计算图的输出是每个分类的得分,但是要求输入的标签是正确答案的下标
    52     cross_entropy_mean=tf.reduce_mean(cross_entropy)
    53 
    54     regularizer=tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
    55     regularization=regularizer(weights1)+regularizer(weights2)
    56     loss=cross_entropy+regularization
    57 
    58     learning_rate=tf.train.exponential_decay(LEARNING_RATE_BASE,global_step,mnist.train.num_examples/BATCH_SIZE,LEARNING_RATE_DECAY)#学习率成阶梯状衰减 每个epoch衰减一次,也就是一整轮数据训练完衰减一次
    59     train_step=tf.train.GradientDescentOptimizer(learning_rate).minimize(loss,global_step=global_step)
    60 
    61     train_op=tf.group(train_step,variable_averages_op)#把反向传播是需要更新的参数打包,不使用滑动平均不需要这句话,因为只更新权重。滑动平均还要利用历史数据更新并更新历史数据
    62 
    63     correct_prediction=tf.equal(tf.argmax(average_y,1),tf.argmax(y_,1))
    64     accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
    65     
    66 
    67     with tf.Session() as sess:
    68         tf.global_variables_initializer().run()
    69         validate_feed={x:mnist.validation.images,y_:mnist.validation.labels}
    70         test_feed={x:mnist.test.images,y_:mnist.test.labels}
    71 
    72         for i in range(TRAINING_STEPS):
    73             if i%1000==0:
    74                 validate_acc=sess.run(accuracy,feed_dict=validate_feed)
    75                 print('After %d training steps,validation accuracy using average model is %g' %(i,validate_acc))
    76                 
    77             xs,ys=mnist.train.next_batch(BATCH_SIZE)
    78             sess.run(train_op,feed_dict={x:xs,y_:ys})
    79 
    80         test_acc=sess.run(accuracy,feed_dict=test_feed)
    81         print('After %d training steps,test accuracy using average model is %g' %(TRAINING_STEPS,test_acc))    
    82 
    83 def main(argv=None):
    84     mnist=input_data.read_data_sets("/tmp/data",one_hot=True)
    85     train(mnist)
    86 
    87 if __name__ == '__main__':
    88     tf.app.run()
  • 相关阅读:
    gitLab 全局hooks和custom_hooks,以及服务器端自动更新和备份(三)
    ORACLE的Copy命令和create table,insert into的比较
    计算机基础
    在C#应用中使用Common Logging日志接口
    数据库设计原则(转载)
    Oracle中函数如何返回结果集
    ORACLE时间常用函数(字段取年、月、日、季度)
    SQLServer2005 没有日志文件(*.ldf) 只有数据文件(*.mdf) 恢复数据库的方法
    sql server日期时间转字符串
    SQL Server删除用户失败的解决方法
  • 原文地址:https://www.cnblogs.com/super-JJboom/p/9596848.html
Copyright © 2011-2022 走看看