zoukankan      html  css  js  c++  java
  • MNIST数据集手写体识别(MLP实现)

    github博客传送门
    csdn博客传送门

    本章所需知识:

    1. 没有基础的请观看深度学习系列视频
    2. tensorflow
    3. Python基础

    资料下载链接:

    1. 深度学习基础网络模型(mnist手写体识别数据集)

    MNIST数据集手写体识别(MLP实现)

    import tensorflow as tf
    import tensorflow.examples.tutorials.mnist.input_data as input_data  # 导入下载数据集手写体
    mnist = input_data.read_data_sets('../MNIST_data/', one_hot=True)
    
    
    class MLPNet:  # 创建一个MLPNet类
        def __init__(self):
            self.x = tf.placeholder(dtype=tf.float32, shape=[None, 784], name='input_x')  # 创建一个tensorflow占位符(稍后传入图片数据),定义数据类型为tf.float32,形状shape为 None为批次 784为数据集撑开的 28*28的手写体图片 name可选参数
            self.y = tf.placeholder(dtype=tf.float32, shape=[None, 10], name='input_label')  # 创建一个tensorflow占位符(稍后传入图片标签), name可选参数
    
            self.w1 = tf.Variable(tf.truncated_normal(shape=[784, 100], dtype=tf.float32, stddev=tf.sqrt(1 / 100)))  # 定义全链接 第一层/输入层 变量w/神经元w/特征数w 为截断正态分布 形状shape为[784, 100](100个神经元个数的矩阵), stddev(标准差的值 一般情况为sqrt(1/当前神经元个数))
            self.b1 = tf.Variable(tf.zeros([100], dtype=tf.float32))  # 定义变量偏值b 值为零 形状shape为 [100] 当前层w 的个数, 数据类型 dtype为tf.float32
    
            self.w2 = tf.Variable(tf.truncated_normal(shape=[100, 10], dtype=tf.float32, stddev=tf.sqrt(1 / 10)))  # 定义全链接 第二层/输出层 变量w/神经元w/特征数w 为截断正态分布 形状为[上一层神经元个数, 当前输出神经元个数] 由于我们是识别十分类问题,手写体 0,1,2,3,4,5,6,7,8,9所以我们选择十个神经元 stddev同上
            self.b2 = tf.Variable(tf.zeros([10], dtype=tf.float32))  # 设置原理同上
    
    	# 前向计算
        def forward(self):
            self.forward_1 = tf.nn.relu(tf.matmul(self.x, self.w1) + self.b1)  # 全链接第一层
            self.forward_2 = tf.nn.relu(tf.matmul(self.forward_1, self.w2) + self.b2)  # 全链接第二层
            self.output = tf.nn.softmax(self.forward_2)  # softmax分类器分类
    	
    	# 后向计算
        def backward(self):
            self.cost = tf.reduce_mean(tf.square(self.output - self.y))  # 定义均方差损失
            self.opt = tf.train.AdamOptimizer().minimize(self.cost)      # 使用AdamOptimizer优化器 优化 self.cost损失函数
    
    	# 计算识别精度
        def acc(self):
    		# 将预测值 output 和 标签值 self.y 进行比较
            self.z = tf.equal(tf.argmax(self.output, 1, name='output_max'), tf.argmax(self.y, 1, name='y_max'))
            # 最后对比较出来的bool值 转换为float32类型后 求均值就可以看到满值为 1的精度显示
    		self.accaracy = tf.reduce_mean(tf.cast(self.z, tf.float32))
    
    
    if __name__ == '__main__':
        net = MLPNet()  # 启动tensorflow绘图的MLPNet
        net.forward()   # 启动前向计算
        net.backward()  # 启动后向计算
        net.acc()       # 启动精度计算
        init = tf.global_variables_initializer()  # 定义初始化tensorflow所有变量操作
        with tf.Session() as sess:                # 创建一个Session会话
            sess.run(init)                        # 执行init变量内的初始化所有变量的操作
            for i in range(10000):                # 训练10000次
                ax, ay = mnist.train.next_batch(100)  # 从mnist数据集中取数据出来 ax接收图片 ay接收标签
                loss, accaracy, _ = sess.run(fetches=[net.cost, net.accaracy, net.opt], feed_dict={net.x: ax, net.y: ay})  # 将数据喂进神经网络(以字典的方式传入) 接收loss返回值
                if i % 1000 == 0:  # 每训练1000次
                    test_ax, test_ay = mnist.test.next_batch(100)  # 则使用测试集对当前网络进行测试
                    test_output = sess.run(net.output, feed_dict={net.x: test_ax})  # 将测试数据喂进网络 接收一个output值
                    z = tf.equal(tf.argmax(test_output, 1, name='output_max'), tf.argmax(test_ay, 1, name='test_y_max'))  # 对output值和标签y值进行求比较运算
                    accaracy2 = sess.run(tf.reduce_mean(tf.cast(z, tf.float32)))  # 求出精度的准确率进行打印
                    print(accaracy2)  # 打印当前测试集的精度 
    

    最后附上训练截图:

    MLP

  • 相关阅读:
    MyBatis3: There is no getter for property named 'code' in 'class java.lang.String'
    jQuery获取Select选择的Text和 Value(转)
    mybatis3 :insert返回插入的主键(selectKey)
    【转】Mybatis/Ibatis,数据库操作的返回值
    Android问题-打开DelphiXE8与DelphiXE10编译空工程提示“[Exec Error] The command exited with code 1.”
    Android问题-打开DelphiXE8与DelphiXE10新建一个空工程提示"out of memory"
    BAT-使用BAT生成快捷方式
    给 TTreeView 添加复选框
    跨进程发送消息数据
    鼠标拖动虚影效果
  • 原文地址:https://www.cnblogs.com/Mrzhang3389/p/9899036.html
Copyright © 2011-2022 走看看