zoukankan      html  css  js  c++  java
  • 吴裕雄--天生自然TensorFlow高层封装:使用TensorFlow-Slim处理MNIST数据集实现LeNet-5模型

    # 1. 通过TensorFlow-Slim定义卷积神经网络
    import numpy as np
    import tensorflow as tf
    import tensorflow.contrib.slim as slim
    
    from tensorflow.examples.tutorials.mnist import input_data
    
    # 通过TensorFlow-Slim来定义LeNet-5的网络结构。
    def lenet5(inputs):
        """Build a LeNet-5-style convolutional network with TensorFlow-Slim.

        Args:
            inputs: float tensor of flattened images. It is reshaped to
                [-1, 28, 28, 1], so each example must contain 784 values.

        Returns:
            Tensor of unnormalized class scores (logits) with 10 columns,
            suitable for a softmax cross-entropy loss.
        """
        inputs = tf.reshape(inputs, [-1, 28, 28, 1])
        # Two stages of 5x5 convolution (SAME padding) followed by 2x2 max-pooling.
        net = slim.conv2d(inputs, 32, [5, 5], padding='SAME', scope='layer1-conv')
        net = slim.max_pool2d(net, 2, stride=2, scope='layer2-max-pool')
        net = slim.conv2d(net, 64, [5, 5], padding='SAME', scope='layer3-conv')
        net = slim.max_pool2d(net, 2, stride=2, scope='layer4-max-pool')
        net = slim.flatten(net, scope='flatten')
        net = slim.fully_connected(net, 500, scope='layer5')
        # slim.fully_connected applies ReLU by default. The output layer must stay
        # linear: a ReLU here would clamp negative logits to 0 before the softmax.
        net = slim.fully_connected(net, 10, activation_fn=None, scope='output')
        return net
    
    # 训练模型。
    def train(mnist):
        """Train the LeNet-5 model on MNIST with plain gradient descent.

        Runs 3000 mini-batch steps (batch size 100, learning rate 0.01) and
        prints the loss on the current training batch every 1000 steps.

        Args:
            mnist: dataset object whose ``train.next_batch(100)`` returns a pair
                (images with 784 columns, one-hot labels with 10 columns).
        """
        images = tf.placeholder(tf.float32, [None, 784], name='x-input')
        labels_one_hot = tf.placeholder(tf.float32, [None, 10], name='y-input')
        logits = lenet5(images)

        # The sparse loss expects integer class ids, so recover them from the one-hot labels.
        class_ids = tf.argmax(labels_one_hot, 1)
        per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=class_ids)
        loss = tf.reduce_mean(per_example_loss)

        optimizer = tf.train.GradientDescentOptimizer(0.01)
        train_op = optimizer.minimize(loss)

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            for step in range(3000):
                batch_images, batch_labels = mnist.train.next_batch(100)
                feed = {images: batch_images, labels_one_hot: batch_labels}
                _, loss_value = sess.run([train_op, loss], feed_dict=feed)
                if step % 1000 == 0:
                    print("After %d training step(s), loss on training batch is %g." % (step, loss_value))
    #  主程序
    def main(argv=None):
        """Load MNIST (one-hot labels) from a local directory and train the model.

        Args:
            argv: not used by this function.
        """
        # Forward slashes avoid backslash escapes. In the original Windows-style
        # literal, the trailing backslash escaped the closing quote, which is a
        # SyntaxError, and "\201" was silently an octal escape.
        mnist = input_data.read_data_sets("F:/TensorFlowGoogle/201806-github/datasets/MNIST_data/", one_hot=True)
        train(mnist)

    if __name__ == '__main__':
        main()

  • 相关阅读:
    Tsinghua 2018 DSA PA3简要题解
    Tsinghua 2018 DSA PA2简要题解
    Python logging系统
    Surface RT2使用情况
    隔壁信概大作业xjb写——同化棋ATAXX
    XJTUOJ #1080 qz的不卡常数
    XJTUOJ #1081 JM的赃物被盗
    XJTUOJ #1078 JM的恶有恶报
    洛谷P5425 [USACO19OPEN]I Would Walk 500 Miles G
    洛谷P2857 [USACO06FEB]Steady Cow Assignment G
  • 原文地址:https://www.cnblogs.com/tszr/p/12069987.html
Copyright © 2011-2022 走看看