zoukankan      html  css  js  c++  java
  • 【TensorFlow篇】--Tensorflow框架实现SoftMax模型识别手写数字集

    一、前述

    本文讲述用Tensorflow框架实现SoftMax模型识别手写数字集,来实现多分类。

    同时对模型的保存和恢复做下示例。

    二、具体原理

    代码一:实现代码

    #!/usr/bin/python
    # -*- coding: UTF-8 -*-
    # 文件名: 12_Softmax_regression.py
    
    from tensorflow.examples.tutorials.mnist import input_data
    import tensorflow as tf
    
    
    # mn.SOURCE_URL = "http://yann.lecun.com/exdb/mnist/"
    my_mnist = input_data.read_data_sets("MNIST_data_bak/", one_hot=True)#从本地路径加载进来
    
    # The MNIST data is split into three parts:
    # 55,000 data points of training data (mnist.train)#训练集图片
    # 10,000 points of test data (mnist.test), and#测试集图片
    # 5,000 points of validation data (mnist.validation).#验证集图片
    
    # Each image is 28 pixels by 28 pixels
    
    # 输入的是一堆图片,None表示不限输入条数,784表示每张图片都是一个784个像素值的一维向量
    # 所以输入的矩阵是None乘以784二维矩阵
    x = tf.placeholder(dtype=tf.float32, shape=(None, 784)) #x矩阵是m行*784列
    # 初始化都是0,二维矩阵784乘以10个W值 #初始值最好不为0
    W = tf.Variable(tf.zeros([784, 10]))#W矩阵是784行*10列
    b = tf.Variable(tf.zeros([10]))#bias也必须有10个
    
    y = tf.nn.softmax(tf.matmul(x, W) + b)# x*w 即为m行10列的矩阵就是y #预测值
    
    # 训练
    # labels是每张图片都对应一个one-hot的10个值的向量
    y_ = tf.placeholder(dtype=tf.float32, shape=(None, 10))#真实值 m行10列
    # 定义损失函数,交叉熵损失函数
    # 对于多分类问题,通常使用交叉熵损失函数
    # reduction_indices等价于axis,指明按照每行加,还是按照每列加
    cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y),
                                                  reduction_indices=[1]))#指明按照列加和 一列是一个类别
    train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)#将损失函数梯度下降 #0.5是学习率
    
    
    # 初始化变量
    sess = tf.InteractiveSession()#初始化Session
    tf.global_variables_initializer().run()#初始化所有变量
    for _ in range(1000):
        batch_xs, batch_ys = my_mnist.train.next_batch(100)#每次迭代取100行数据
        sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
    #每次迭代内部就是求梯度,然后更新参数
    # 评估
    
    # tf.argmax()是一个从tensor中寻找最大值的序号 就是分类号,tf.argmax就是求各个预测的数字中概率最大的那一个
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    
    # 用tf.cast将之前correct_prediction输出的bool值转换为float32,再求平均
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    
    # 测试
    print(accuracy.eval({x: my_mnist.test.images, y_: my_mnist.test.labels}))
    
    # 总结
    # 1,定义算法公式,也就是神经网络forward时的计算
    # 2,定义loss,选定优化器,并指定优化器优化loss
    # 3,迭代地对数据进行训练
    # 4,在测试集或验证集上对准确率进行评测

    代码二:保存模型

    # 有时候需要把模型保持起来,有时候需要做一些checkpoint在训练中
    # 以致于如果计算机宕机,我们还可以从之前checkpoint的位置去继续
    # TensorFlow使得我们去保存和加载模型非常方便,仅需要去创建Saver节点在构建阶段最后
    # 然后在计算阶段去调用save()方法
    
    from tensorflow.examples.tutorials.mnist import input_data
    import tensorflow as tf
    
    
    # mn.SOURCE_URL = "http://yann.lecun.com/exdb/mnist/"
    my_mnist = input_data.read_data_sets("MNIST_data_bak/", one_hot=True)
    
    # The MNIST data is split into three parts:
    # 55,000 data points of training data (mnist.train)
    # 10,000 points of test data (mnist.test), and
    # 5,000 points of validation data (mnist.validation).
    
    # Each image is 28 pixels by 28 pixels
    
    # 输入的是一堆图片,None表示不限输入条数,784表示每张图片都是一个784个像素值的一维向量
    # 所以输入的矩阵是None乘以784二维矩阵
    x = tf.placeholder(dtype=tf.float32, shape=(None, 784))
    # 初始化都是0,二维矩阵784乘以10个W值
    W = tf.Variable(tf.zeros([784, 10]))
    b = tf.Variable(tf.zeros([10]))
    
    y = tf.nn.softmax(tf.matmul(x, W) + b)
    
    # 训练
    # labels是每张图片都对应一个one-hot的10个值的向量
    y_ = tf.placeholder(dtype=tf.float32, shape=(None, 10))
    # 定义损失函数,交叉熵损失函数
    # 对于多分类问题,通常使用交叉熵损失函数
    # reduction_indices等价于axis,指明按照每行加,还是按照每列加
    cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y),
                                                  reduction_indices=[1]))
    train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
    
    # 初始化变量
    init = tf.global_variables_initializer()
    # 创建Saver()节点
    saver = tf.train.Saver()#在运算之前,初始化之后
    
    n_epoch = 1000
    
    with tf.Session() as sess:
        sess.run(init)
        for epoch in range(n_epoch):
            if epoch % 100 == 0:
                save_path = saver.save(sess, "./my_model.ckpt")#每跑100次save一次模型,可以保证容错性
                #直接保存session即可。
    
            batch_xs, batch_ys = my_mnist.train.next_batch(100)#每一批次跑的数据 用m行数据/迭代次数来计算出来。
            sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
    
        best_theta = W.eval()
        save_path = saver.save(sess, "./my_model_final.ckpt")#保存最后的模型,session实际上保存的上面所有的数据

    代码三:恢复模型

    from tensorflow.examples.tutorials.mnist import input_data
    import tensorflow as tf
    
    
    # mn.SOURCE_URL = "http://yann.lecun.com/exdb/mnist/"
    my_mnist = input_data.read_data_sets("MNIST_data_bak/", one_hot=True)
    
    # The MNIST data is split into three parts:
    # 55,000 data points of training data (mnist.train)
    # 10,000 points of test data (mnist.test), and
    # 5,000 points of validation data (mnist.validation).
    
    # Each image is 28 pixels by 28 pixels
    
    # 输入的是一堆图片,None表示不限输入条数,784表示每张图片都是一个784个像素值的一维向量
    # 所以输入的矩阵是None乘以784二维矩阵
    x = tf.placeholder(dtype=tf.float32, shape=(None, 784))
    # 初始化都是0,二维矩阵784乘以10个W值
    W = tf.Variable(tf.zeros([784, 10]))
    b = tf.Variable(tf.zeros([10]))
    
    y = tf.nn.softmax(tf.matmul(x, W) + b)
    # labels是每张图片都对应一个one-hot的10个值的向量
    y_ = tf.placeholder(dtype=tf.float32, shape=(None, 10))
    
    saver = tf.train.Saver()
    
    with tf.Session() as sess:
        saver.restore(sess, "./my_model_final.ckpt")#把路径下面所有的session的数据加载进来 y y_head还有模型都保存下来了。
    
        # 评估
        # tf.argmax()是一个从tensor中寻找最大值的序号,tf.argmax就是求各个预测的数字中概率最大的那一个
        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    
        # 用tf.cast将之前correct_prediction输出的bool值转换为float32,再求平均
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    
        # 测试
        print(accuracy.eval({x: my_mnist.test.images, y_: my_mnist.test.labels}))
  • 相关阅读:
    获得当前python解释器的路径
    AirtestIDE
    大数据到底有多大?TB、PB、EB到底是多少?
    时间的单位有
    windows10 彻底关闭自动更新
    Microsoft Windows10系统时间显示秒的方法
    host文件路径(Windows)
    Mina学习之IoHandler
    Mina学习之IoFilter
    Mina学习之IoSession
  • 原文地址:https://www.cnblogs.com/LHWorldBlog/p/8661434.html
Copyright © 2011-2022 走看看