zoukankan      html  css  js  c++  java
  • Tensorflow学习—— AdamOptimizer

    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data

    #载入数据集
    mnist = input_data.read_data_sets("F:\TensorflowProject\MNIST_data",one_hot=True)

    #每个批次的大小,训练时一次100张放入神经网络中训练
    batch_size = 100

    #计算一共有多少个批次
    n_batch = mnist.train.num_examples//batch_size

    #定义两个placeholder
    x = tf.placeholder(tf.float32,[None,784])
    #0-9十个数字
    y = tf.placeholder(tf.float32,[None,10])

    #创建一个神经网络
    W = tf.Variable(tf.zeros([784,10]))
    b = tf.Variable(tf.zeros([10]))
    prediction = tf.nn.softmax(tf.matmul(x,W)+b)

    #二次代价函数
    #loss = tf.reduce_mean(tf.square(y-prediction))
    #交叉熵代价函数
    #使用交叉熵定义代价函数,可以加快模型收敛速度
    #loss = tf.nn.softmax_cross_entropy_with_logits(labels=y,logits=prediction)
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y,logits=prediction))
    #使用梯度下降法
    #train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)
    train_step = tf.train.AdamOptimizer(0.01).minimize(loss) #1e-2


    #初始化变量
    init = tf.global_variables_initializer()

    #结果存放在一个布尔型列表中
    correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))
    #求准确率
    accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

    #
    with tf.Session() as sess:
      sess.run(init)
      for epoch in range(21):
        for batch in range(n_batch):
          batch_xs,batch_ys = mnist.train.next_batch(batch_size)
          sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys})

        #测试准确率
        acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
        print("Iter: "+str(epoch)+" ,Testing Accuracy "+str(acc))

    ###########运行结果

    Extracting F:TensorflowProjectMNIST_data	rain-images-idx3-ubyte.gz
    Extracting F:TensorflowProjectMNIST_data	rain-labels-idx1-ubyte.gz
    Extracting F:TensorflowProjectMNIST_data	10k-images-idx3-ubyte.gz
    Extracting F:TensorflowProjectMNIST_data	10k-labels-idx1-ubyte.gz
    Iter: 0  ,Testing Accuracy  0.9221
    Iter: 1  ,Testing Accuracy  0.9133
    Iter: 2  ,Testing Accuracy  0.9271
    Iter: 3  ,Testing Accuracy  0.9262
    Iter: 4  ,Testing Accuracy  0.9299
    Iter: 5  ,Testing Accuracy  0.9293
    Iter: 6  ,Testing Accuracy  0.9301
    Iter: 7  ,Testing Accuracy  0.9299
    Iter: 8  ,Testing Accuracy  0.9287
    Iter: 9  ,Testing Accuracy  0.9319
    Iter: 10  ,Testing Accuracy  0.9317
    Iter: 11  ,Testing Accuracy  0.9315
    Iter: 12  ,Testing Accuracy  0.9307
    Iter: 13  ,Testing Accuracy  0.932
    Iter: 14  ,Testing Accuracy  0.9314
    Iter: 15  ,Testing Accuracy  0.9316
    Iter: 16  ,Testing Accuracy  0.9311
    Iter: 17  ,Testing Accuracy  0.9333
    Iter: 18  ,Testing Accuracy  0.9318
    Iter: 19  ,Testing Accuracy  0.9318
    Iter: 20  ,Testing Accuracy  0.9289
  • 相关阅读:
    HBase 高性能加入数据
    Please do not register multiple Pages in undefined.js 小程序报错的几种解决方案
    小程序跳转时传多个参数及获取
    vue项目 调用百度地图 BMap is not defined
    vue生命周期小笔记
    解决小程序背景图片在真机上不能查看的问题
    vue项目 菜单侧边栏随着右侧内容盒子的高度实时变化
    vue项目 一行js代码搞定点击图片放大缩小
    微信小程序进行地图导航使用地图功能
    小程序报错Do not have xx handler in current page的解决方法
  • 原文地址:https://www.cnblogs.com/herd/p/9467849.html
Copyright © 2011-2022 走看看