zoukankan      html  css  js  c++  java
  • tensorflow入门:CNN for MNIST

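    (原文此处的网络结构图已缺失。)
    使用 TensorFlow 构建如下结构的 CNN,对 MNIST 数据集进行 softmax 分类:3 个"3×3 卷积(ReLU)→ 2×2 最大池化 → dropout(0.3)"模块(卷积核数依次为 32、64、128),后接一个 625 单元的 ReLU 全连接层(dropout 0.5)和一个 10 单元的输出层。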

    理论部分不再赘述,完整的代码如下:

    import tensorflow as tf
    import numpy as np
    
    from tensorflow.examples.tutorials.mnist import input_data
    
    # Training hyperparameters, shared by every ensemble member below.
    learning_rate = 1e-3
    training_epoches = 20   # number of full passes over mnist.train
    batch_size = 100        # examples per optimisation step

    # Load (downloading into MNIST_data/ if needed) MNIST with one-hot labels.
    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
    
    class Model:
        """A single CNN classifier for MNIST built in a TensorFlow 1.x graph.

        Architecture (see ``_build_net``): three stages of
        [3x3 conv + ReLU -> 2x2 / stride-2 max-pool -> dropout(0.3)] with
        32, 64 and 128 filters, then a 625-unit ReLU dense layer with
        dropout(0.5), then a 10-unit linear layer producing the logits.

        Several instances can share one graph/session (the script below trains
        an ensemble of them); each instance builds its layers inside
        ``tf.variable_scope(name)`` so their variables are namespaced.
        """

        def __init__(self, sess, name):
            """Store the session and build this model's part of the graph.

            Args:
                sess: ``tf.Session`` used by ``train``, ``predict`` and
                    ``get_accuracy``.
                name: variable-scope name for this model's layers; presumably
                    expected to be unique per model in the graph.
            """
            self.sess = sess
            self.name = name
            self._build_net()

        def _build_net(self):
            """Create the placeholders, layers, loss, optimizer and accuracy ops."""
            with tf.variable_scope(self.name):
                # Dropout switch. Defaults to False (inference) so the
                # evaluation ops also work when it is not fed; train() feeds True.
                self.training = tf.placeholder_with_default(False, shape=())
                # Input placeholders: flattened 28x28 images and one-hot labels.
                self.X = tf.placeholder(tf.float32, [None, 784])
                self.Y = tf.placeholder(tf.float32, [None, 10])
                # Images as 28x28x1 (single grey channel).
                X_img = tf.reshape(self.X, [-1, 28, 28, 1])

                # Conv layer 1 & pooling layer 1: 28x28x1 -> 14x14x32
                conv1 = tf.layers.conv2d(inputs=X_img, filters=32, kernel_size=[3, 3],
                                         padding="SAME", activation=tf.nn.relu)
                pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2],
                                                padding="SAME", strides=2)
                dropout1 = tf.layers.dropout(inputs=pool1, rate=0.3, training=self.training)

                # Conv layer 2 & pooling layer 2: 14x14x32 -> 7x7x64
                conv2 = tf.layers.conv2d(inputs=dropout1, filters=64, kernel_size=[3, 3],
                                         padding="SAME", activation=tf.nn.relu)
                pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2],
                                                padding="SAME", strides=2)
                dropout2 = tf.layers.dropout(inputs=pool2, rate=0.3, training=self.training)

                # Conv layer 3 & pooling layer 3: 7x7x64 -> 4x4x128
                # ("SAME" pooling rounds 7 / 2 up to 4).
                conv3 = tf.layers.conv2d(inputs=dropout2, filters=128, kernel_size=[3, 3],
                                         padding="SAME", activation=tf.nn.relu)
                pool3 = tf.layers.max_pooling2d(inputs=conv3, pool_size=[2, 2],
                                                padding="SAME", strides=2)
                dropout3 = tf.layers.dropout(inputs=pool3, rate=0.3, training=self.training)

                # Dense layer with ReLU: 4*4*128 = 2048 -> 625
                flat = tf.reshape(dropout3, [-1, 128 * 4 * 4])
                dense4 = tf.layers.dense(inputs=flat, units=625, activation=tf.nn.relu)
                dropout4 = tf.layers.dropout(inputs=dense4, rate=0.5, training=self.training)

                # Output layer 625 -> 10, no activation (raw logits).
                self.logits = tf.layers.dense(inputs=dropout4, units=10)

            # Loss & optimizer. Variables of other models in the same graph get
            # no gradient from this cost, so minimize() leaves them untouched.
            self.cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
                logits=self.logits, labels=self.Y))
            self.optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(self.cost)

            # Accuracy: fraction of examples whose arg-max logit matches the label.
            correct_prediction = tf.equal(tf.argmax(self.logits, 1), tf.argmax(self.Y, 1))
            self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

        def train(self, x_data, y_data, training=True):
            """Run one optimisation step on a batch.

            Returns:
                ``[cost, None]`` -- the batch cost and the (valueless) train op result.
            """
            return self.sess.run([self.cost, self.optimizer],
                                 feed_dict={self.X: x_data, self.Y: y_data, self.training: training})

        def predict(self, x_test, training=False):
            """Return the raw logits (one row of 10 scores per input image)."""
            return self.sess.run(self.logits,
                                 feed_dict={self.X: x_test, self.training: training})

        def get_accuracy(self, x_test, y_test, training=False):
            """Return the classification accuracy on ``(x_test, y_test)``."""
            return self.sess.run(self.accuracy,
                                 feed_dict={self.X: x_test, self.Y: y_test, self.training: training})
        
    
    # train the models
    with tf.Session() as sess:
        num_models = 2
        # Build every ensemble member in the same graph/session.
        models = [Model(sess, "modal" + str(m)) for m in range(num_models)]

        sess.run(tf.global_variables_initializer())

        print('Learning Start!')

        # Mini-batches per epoch; loop-invariant, so computed once.
        total_batch = mnist.train.num_examples // batch_size

        for epoch in range(training_epoches):
            avg_cost_list = np.zeros(len(models))
            for _batch in range(total_batch):
                batch_xs, batch_ys = mnist.train.next_batch(batch_size)

                # Train each model on the same mini-batch.
                for m_id, m in enumerate(models):
                    c, _ = m.train(batch_xs, batch_ys)
                    avg_cost_list[m_id] += c / total_batch

            print('Epoch: ', '%04d' % (epoch + 1), 'cost=', avg_cost_list)

        print('Learning finished!')

        # test & accuracy
        test_size = len(mnist.test.labels)
        predictions = np.zeros([test_size, 10])

        for m_id, m in enumerate(models):
            print(m_id, "Accuracy:", m.get_accuracy(mnist.test.images, mnist.test.labels))
            # Ensemble by summing the raw logits of every model.
            predictions += m.predict(mnist.test.images)

        # Score the ensemble in NumPy: creating tf.equal / tf.reduce_mean ops here
        # would add nodes to the graph after training and embed the whole test
        # label array as a constant, only to compute a mean.
        ensemble_correct_prediction = np.equal(np.argmax(predictions, 1),
                                               np.argmax(mnist.test.labels, 1))
        ensemble_accuracy = np.mean(ensemble_correct_prediction)
        print("Ensemble_accuracy:", ensemble_accuracy)
    

    结果:

    Learning Start!
    Epoch:  0001 cost= [0.29211415 0.28355632]
    Epoch:  0002 cost= [0.08716567 0.0870499 ]
    Epoch:  0003 cost= [0.06902521 0.06623169]
    Epoch:  0004 cost= [0.05563359 0.05452387]
    Epoch:  0005 cost= [0.04963774 0.04871382]
    Epoch:  0006 cost= [0.04462749 0.04449957]
    Epoch:  0007 cost= [0.04132144 0.03907955]
    Epoch:  0008 cost= [0.03792324 0.03861412]
    Epoch:  0009 cost= [0.0354344  0.03323769]
    Epoch:  0010 cost= [0.03516847 0.03405525]
    Epoch:  0011 cost= [0.03143759 0.03219781]
    Epoch:  0012 cost= [0.03051504 0.02993162]
    Epoch:  0013 cost= [0.02906878 0.02711077]
    Epoch:  0014 cost= [0.02729127 0.02754832]
    Epoch:  0015 cost= [0.02729633 0.02632647]
    Epoch:  0016 cost= [0.02438517 0.02701174]
    Epoch:  0017 cost= [0.02482958 0.0244114 ]
    Epoch:  0018 cost= [0.02455271 0.02649499]
    Epoch:  0019 cost= [0.02371975 0.02178147]
    Epoch:  0020 cost= [0.02260135 0.0213784 ]
    Learning finished!
    0 Accuracy: 0.995
    1 Accuracy: 0.9949
    Ensemble_accuracy: 0.9954
    

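    运行结果前其实还有很长的 warning,这里没有贴出。warning 的大意是:`tensorflow.examples.tutorials.mnist` 中的数据读取接口已被弃用,新版本的 TensorFlow 建议改用其他方式加载 MNIST 数据集。这篇博文仅作示例。实际使用 TensorFlow 时,一般需要根据数据集的存储格式自己编写读取数据的代码。另外,本文代码使用的是 TensorFlow 1.x 的 API(`tf.placeholder`、`tf.layers`、`tf.Session`);在 TensorFlow 2.x 中 `tensorflow.examples.tutorials` 模块已被移除,代码无法直接运行。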

  • 相关阅读:
    iOS 使用GRMustache对HTML页面进行渲染
    算法 -- 排序
    ios 笔记
    ios 开发视图界面动态渲染
    Python环境变量设置
    Excel2010: Excel使用小技巧(不断更新)
    C: Answers to “The C programming language, Edition 2”
    VBScript: Windows脚本宿主介绍
    VBScript: 正则表达式(RegExp对象)
    VBScript Sample:遍历文件夹并获取XML文件中指定内容
  • 原文地址:https://www.cnblogs.com/wanghongze95/p/13842481.html
Copyright © 2011-2022 走看看