zoukankan      html  css  js  c++  java
  • tensorflow实现mnist

    import tensorflow as tf
    
    from tensorflow.examples.tutorials.mnist import input_data
    
    # 在变量的构建时,通过truncated_normal 函数初始化权重变量,shape 是一个二维的tensor,从截断的正太分布中输出随机值。
    def weight_varible(shape):
        initial=tf.truncated_normal(shape,stddev=0.1)
        return tf.Variable(initial)
    
    def bias_varible(shape):
        initial=tf.constant(0.1,shape=shape)
        return tf.Variable(initial)
        
    def conv2d(x,w):
        # x 输入,是一个tensor,w 为filter,strides卷积时每一维的步长,padding 参数是string类型的值,为same,或valid,use_cudnn_on_gpu: bool类型,默认是true
        return tf.nn.conv2d(x,w,strides=[1,1,1,1],padding='SAME')
    def max_pool_2x2(x):
        # x 为池化的输入,ksize 为窗口的大小,四维向量,一般为【1,height,width,1】, 因为不在channel上做池化,stride 和卷积类型,窗口在每一维上滑动的步长,【1,stride,stride,1】
        return tf.nn.max_pool(x,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')
    
    # tensorflow 中已经写好了加载mnist 数据集的脚本,minist是一个轻量级的类文件,存储了Numpy格式的训练集,验证集。。同时提供了数据中mini-batch迭代的功能
    mnist=input_data.read_data_sets('MNIST_data/',one_hot=True)
    
    # 输入数据参数
    x=tf.placeholder(tf.float32,[None,784])
    # 当某轴为-1 时,根据数组元素的个数自动计算此轴的长度
    x_image=tf.reshape(x,[-1,28,28,1])
    y=tf.placeholder(tf.float32,[None,10])
    
    # 后台C++,通过session与后台连接, 在session中运行创建的图,
    # 使用InteractiveSession 更方便,允许交互操作,如果不用InteractiveSession, 在启动一个回话和运行图之前,要创建整个流图
    sess=tf.InteractiveSession()
    
    
    # 第一个卷积层
    w_conv1=weight_varible([5,5,1,32])
    b_conv1=bias_varible([32])
    h_conv1=tf.nn.relu(conv2d(x_image,w_conv1)+b_conv1)
    h_pool1=max_pool_2x2(h_conv1)
    
    # 第二个卷积层
    w_conv2=weight_varible([5,5,32,64])
    b_conv2=bias_varible([64])
    h_conv2=tf.nn.relu(conv2d(h_pool1,w_conv2)+b_conv2)
    h_pool2=max_pool_2x2(h_conv2)
    
    #全连接层
    # 经过两次pool 之后大小变为7*7
    w_fc1=weight_varible([7*7*64,1024])
    b_fc1=bias_varible([1024])
    
    h_pool2_flat=tf.reshape(h_pool2,[-1,7*7*64]);
    h_fc1=tf.nn.relu(tf.matmul(h_pool2_flat,w_fc1)+b_fc1)
    
    # dropout 层
    keep_prob=tf.placeholder(tf.float32)
    h_fc1_drop=tf.nn.dropout(h_fc1,keep_prob)
    
    #output softmax
    
    w_fc2=weight_varible([1024,10])
    b_fc2=bias_varible([10])
    y_conv=tf.nn.softmax(tf.matmul(h_fc1_drop,w_fc2)+b_fc2)
    
    # 定义loss 函数,和最优化函数
    cross_entropy=-tf.reduce_sum(y*tf.log(y_conv))
    train_step=tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
    
    
    # 定义eval
    correct_prediction=tf.equal(tf.arg_max(y_conv,1),tf.arg_max(y,1))
    accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))# cast 为类型转换函数,将类型转换为tf.float32
    
    sess.run(tf.initialize_all_variables())
    
    # 开始正式的训练
    for i in range(20000):
        batch=mnist.train.next_batch(50)
        # mnist 类中提供了取batch的函数
        if i%100==0:
            
           # train_accuracy=accuracy.eval()
            print("step %d, training accuracy %g" %(i, sess.run(accuracy,feed_dict={x:batch[0],y:batch[1],keep_prob:1.0})))
            
        sess.run(train_step,feed_dict={x:batch[0],y:batch[1],keep_prob:0.5})
        #train_step.run()
        
    print("test accuracy %g" %(accuracy.eval(feed_dict={x:mnist.test.images,y:mnist.test.labels,keep_prob:1.0})))
    

      

    sess.run(accuracy,feed_dict={x:batch[0],y:batch[1],keep_prob:1.0})

    accuracy.eval(feed_dict={x:mnist.test.images,y:mnist.test.labels,keep_prob:1.0})

    accuracy.eval 相当于 tf.get_default_session().run(t)

    accuracy.eval()==sess.run(t)

    而sess.run() 可以在同一步骤获取更多的张量的值,: sess.run(accuracy, train_step)

    sess.run 和eval 每次都从头执行graph, 要缓存计算结果,需要分配tf.Variable

    注意注意: 我们 需要关闭回话
    sess.close()
    如果想不显式的调用close, 可以用with 代码块
    with tf.Session() as sess:
        ...
    

      如果需要存储参数:

    saver=tf.train.Saver()
    save_path=saver.save(sess,model_path)
    load_path=saver.restore(sess,model_path)
    

      





  • 相关阅读:
    lxml库
    requests库基本使用
    Xpath Helper的使用
    Class.forName()的作用(转)
    JDBC 连接数据库
    IDEA 的 Othere Settings(Default settings)消失了?(转)
    servletContext.getRealPath(String)作用(转)
    MySQL中插入相关
    MyBatis 中错误信息详情、原因分析及解决方案
    Java 的全限定类名
  • 原文地址:https://www.cnblogs.com/fanhaha/p/7258047.html
Copyright © 2011-2022 走看看