zoukankan      html  css  js  c++  java
  • tensorflow实现mnist

    import tensorflow as tf
    
    from tensorflow.examples.tutorials.mnist import input_data
    
    # 在变量的构建时,通过truncated_normal 函数初始化权重变量,shape 是一个二维的tensor,从截断的正太分布中输出随机值。
    def weight_varible(shape):
        initial=tf.truncated_normal(shape,stddev=0.1)
        return tf.Variable(initial)
    
    def bias_varible(shape):
        initial=tf.constant(0.1,shape=shape)
        return tf.Variable(initial)
        
    def conv2d(x,w):
        # x 输入,是一个tensor,w 为filter,strides卷积时每一维的步长,padding 参数是string类型的值,为same,或valid,use_cudnn_on_gpu: bool类型,默认是true
        return tf.nn.conv2d(x,w,strides=[1,1,1,1],padding='SAME')
    def max_pool_2x2(x):
        # x 为池化的输入,ksize 为窗口的大小,四维向量,一般为【1,height,width,1】, 因为不在channel上做池化,stride 和卷积类型,窗口在每一维上滑动的步长,【1,stride,stride,1】
        return tf.nn.max_pool(x,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')
    
    # tensorflow 中已经写好了加载mnist 数据集的脚本,minist是一个轻量级的类文件,存储了Numpy格式的训练集,验证集。。同时提供了数据中mini-batch迭代的功能
    mnist=input_data.read_data_sets('MNIST_data/',one_hot=True)
    
    # 输入数据参数
    x=tf.placeholder(tf.float32,[None,784])
    # 当某轴为-1 时,根据数组元素的个数自动计算此轴的长度
    x_image=tf.reshape(x,[-1,28,28,1])
    y=tf.placeholder(tf.float32,[None,10])
    
    # 后台C++,通过session与后台连接, 在session中运行创建的图,
    # 使用InteractiveSession 更方便,允许交互操作,如果不用InteractiveSession, 在启动一个回话和运行图之前,要创建整个流图
    sess=tf.InteractiveSession()
    
    
    # 第一个卷积层
    w_conv1=weight_varible([5,5,1,32])
    b_conv1=bias_varible([32])
    h_conv1=tf.nn.relu(conv2d(x_image,w_conv1)+b_conv1)
    h_pool1=max_pool_2x2(h_conv1)
    
    # 第二个卷积层
    w_conv2=weight_varible([5,5,32,64])
    b_conv2=bias_varible([64])
    h_conv2=tf.nn.relu(conv2d(h_pool1,w_conv2)+b_conv2)
    h_pool2=max_pool_2x2(h_conv2)
    
    #全连接层
    # 经过两次pool 之后大小变为7*7
    w_fc1=weight_varible([7*7*64,1024])
    b_fc1=bias_varible([1024])
    
    h_pool2_flat=tf.reshape(h_pool2,[-1,7*7*64]);
    h_fc1=tf.nn.relu(tf.matmul(h_pool2_flat,w_fc1)+b_fc1)
    
    # dropout 层
    keep_prob=tf.placeholder(tf.float32)
    h_fc1_drop=tf.nn.dropout(h_fc1,keep_prob)
    
    #output softmax
    
    w_fc2=weight_varible([1024,10])
    b_fc2=bias_varible([10])
    y_conv=tf.nn.softmax(tf.matmul(h_fc1_drop,w_fc2)+b_fc2)
    
    # 定义loss 函数,和最优化函数
    cross_entropy=-tf.reduce_sum(y*tf.log(y_conv))
    train_step=tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
    
    
    # 定义eval
    correct_prediction=tf.equal(tf.arg_max(y_conv,1),tf.arg_max(y,1))
    accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))# cast 为类型转换函数,将类型转换为tf.float32
    
    sess.run(tf.initialize_all_variables())
    
    # 开始正式的训练
    for i in range(20000):
        batch=mnist.train.next_batch(50)
        # mnist 类中提供了取batch的函数
        if i%100==0:
            
           # train_accuracy=accuracy.eval()
            print("step %d, training accuracy %g" %(i, sess.run(accuracy,feed_dict={x:batch[0],y:batch[1],keep_prob:1.0})))
            
        sess.run(train_step,feed_dict={x:batch[0],y:batch[1],keep_prob:0.5})
        #train_step.run()
        
    print("test accuracy %g" %(accuracy.eval(feed_dict={x:mnist.test.images,y:mnist.test.labels,keep_prob:1.0})))
    

      

    sess.run(accuracy,feed_dict={x:batch[0],y:batch[1],keep_prob:1.0})

    accuracy.eval(feed_dict={x:mnist.test.images,y:mnist.test.labels,keep_prob:1.0})

    accuracy.eval 相当于 tf.get_default_session().run(t)

    accuracy.eval()==sess.run(t)

    而sess.run() 可以在同一步骤获取更多的张量的值,: sess.run(accuracy, train_step)

    sess.run 和eval 每次都从头执行graph, 要缓存计算结果,需要分配tf.Variable

    注意注意: 我们 需要关闭回话
    sess.close()
    如果想不显式的调用close, 可以用with 代码块
    with tf.Session() as sess:
        ...
    

      如果需要存储参数:

    saver=tf.train.Saver()
    save_path=saver.save(sess,model_path)
    load_path=saver.restore(sess,model_path)
    

      





  • 相关阅读:
    桟错误分析方法
    gstreamer调试命令
    sqlite的事务和锁,很透彻的讲解 【转】
    严重: Exception starting filter struts2 java.lang.NullPointerException (转载)
    eclipse 快捷键
    POJ 1099 Square Ice
    HDU 1013 Digital Roots
    HDU 1087 Super Jumping! Jumping! Jumping!(动态规划)
    HDU 1159 Common Subsequence
    HDU 1069 Monkey and Banana(动态规划)
  • 原文地址:https://www.cnblogs.com/fanhaha/p/7258047.html
Copyright © 2011-2022 走看看