zoukankan      html  css  js  c++  java
  • 基于多层感知机的手写数字识别(Tensorflow实现)

    import numpy as np
    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    import os
    
    mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
    
    class MNISTModel(object):
        def __init__(self, lr, batch_size, iter_num):
            self.lr = lr
            self.batch_size = batch_size
            self.iter_num = iter_num
            # 定义模型结构
            # 输入张量,这里还没有数据,先占个地方,所以叫“placeholder”
            self.x = tf.placeholder(tf.float32, [None, 784])   # 图像是28*28的大小
            self.y = tf.placeholder(tf.float32, [None, 10])    # 输出是0-9的one-hot向量
            self.h = tf.layers.dense(self.x, 100, activation=tf.nn.relu, use_bias=True, kernel_initializer=tf.truncated_normal_initializer) # 一个全连接层
            self.y_ = tf.layers.dense(self.h, 10, use_bias=True, kernel_initializer=tf.truncated_normal_initializer) # 全连接层
            
            # 使用交叉熵损失函数
            self.loss = tf.losses.softmax_cross_entropy(self.y, self.y_)
            self.optimizer = tf.train.AdamOptimizer()
            self.train_step = self.optimizer.minimize(self.loss)
            
            # 用于模型训练
            self.correct_prediction = tf.equal(tf.argmax(self.y, axis=1), tf.argmax(self.y_, axis=1))
            self.accuracy = tf.reduce_mean(tf.cast(self.correct_prediction, tf.float32))
            
            # 用于保存训练好的模型
            self.saver = tf.train.Saver()
            
        def train(self):
            with tf.Session() as sess:            #  打开一个会话。可以想象成浏览器打开一个标签页一样,直观地理解一下
                sess.run(tf.global_variables_initializer())  # 先初始化所有变量。
                for i in range(self.iter_num):
                    batch_x, batch_y = mnist.train.next_batch(self.batch_size)   # 读取一批数据
                    loss, _ = sess.run([self.loss, self.train_step], feed_dict={self.x: batch_x, self.y: batch_y})   # 每调用一次sess.run,就像拧开水管一样,所有self.loss和self.train_step涉及到的运算都会被调用一次。
                    if i%1000 == 0:    
                        train_accuracy = sess.run(self.accuracy, feed_dict={self.x: batch_x, self.y: batch_y})  # 把训练集数据装填进去
                        test_x, test_y = mnist.test.next_batch(self.batch_size)
                        test_accuracy = sess.run(self.accuracy, feed_dict={self.x: test_x, self.y: test_y})   # 把测试集数据装填进去
                        print( 'iter	%i	loss	%f	train_accuracy	%f	test_accuracy	%f' % (i,loss,train_accuracy,test_accuracy))
                self.saver.save(sess, 'model/mnistModel') # 保存模型
    
        def test(self):
            with tf.Session() as sess:
                self.saver.restore(sess, 'model/mnistModel')
                Accuracy = []
                for i in range(150):
                    test_x, test_y = mnist.test.next_batch(self.batch_size)
                    test_accuracy = sess.run(self.accuracy, feed_dict={self.x: test_x, self.y: test_y})
                    Accuracy.append(test_accuracy)
                print ('==' * 15)
                print ('Test Accuracy: ', np.mean(np.array(Accuracy)))
    
    model = MNISTModel(0.001, 64, 40000)   # 学习率为0.001,每批传入64张图,训练40000次
    model.train()      # 训练模型
    model.test()       #测试模型
    
  • 相关阅读:
    AlexNet详解3
    ReLU为什么比Sigmoid效果好
    AlexNet详解2
    AlexNet详解
    微波炉蒸馄饨
    FM与PM信号的表现形式
    HTML与CSS:结构与表现
    CentOS 7安装WordPress
    nginx gzip配置
    minIni: A minimal INI file parser
  • 原文地址:https://www.cnblogs.com/shayue/p/10386107.html
Copyright © 2011-2022 走看看