zoukankan      html  css  js  c++  java
  • TensorFlow实现Softmax回归【模型存储与加载】

     Softmax

    一.Softmax回归简介

      案例:MNIST手写数字识别

      1.为了得到一张给定图片属于某个特定数字类的证据【evidence】,对图片像素进行加权求和。如果这个像素具有很强的证据说明这张图片不属于该类,那么相应的权值为负值相反如果这个像素拥有有利的证据支持这张图片属于这个类,那么权值即为正数。

        

      如下图,红色代表负数值,蓝色代表正数值:

        

      2.这里的softmax可以看做一个激励【activation】函数或者链接【link】函数,把我们定义的线性函数的输出转化成我们想要的格式,也就是关于10个数字类别的概率分布。因此,给定一张图片,它对于每一个数字的吻合度可以被softmax函数转化成一个概率值。

        

        

      展开等式右边的子式:

        

      3.softmax把输入值当成幂指数求值,再正则化这些结果值。这个幂运算表示,更大的证据对应更大的假设模型【hypothesis】里面的乘数权重值。反之拥有更少的证据意味着在假设模型里面拥有更小的乘数系数。假设模型里面的权值不可以是小于0的数值。Softmax会正则化这些权重值,使它们的总和等于1,以此构造一个有效的概率分布。

        

      如果把它写成一个等式:

        

      转化为矩阵乘和向量加:

        

      转化为公式:

        

    二.代码实现

     1 # -*- coding: utf-8 -*-
     2 """
     3 Created on Thu Oct 18 18:02:26 2018
     4 
     5 @author: zhen
     6 """
     7 
     8 from tensorflow.examples.tutorials.mnist import input_data
     9 import tensorflow as tf
    10 
    11 # mn.SOURCE_URL = "http://yann.lecun.com/exdb/mnist/"
    12 my_mnist = input_data.read_data_sets("C:/Users/zhen/MNIST_data_bak/", one_hot=True)
    13 
    14 # The MNIST data is split into three parts:
    15 # 55,000 data points of training data (mnist.train)
    16 # 10,000 points of test data (mnist.test), and
    17 # 5,000 points of validation data (mnist.validation).
    18 
    19 # Each image is 28 pixels by 28 pixels
    20 
    21 # 输入的是一堆图片,None表示不限输入条数,784表示每张图片都是一个784个像素值的一维向量
    22 # 所以输入的矩阵是None乘以784二维矩阵
    23 x = tf.placeholder(dtype=tf.float32, shape=(None, 784))
    24 # 初始化都是0,二维矩阵784乘以10个W值
    25 W = tf.Variable(tf.zeros([784, 10]))
    26 b = tf.Variable(tf.zeros([10]))
    27 
    28 y = tf.nn.softmax(tf.matmul(x, W) + b)
    29 
    30 # 训练
    31 # labels是每张图片都对应一个one-hot的10个值的向量
    32 y_ = tf.placeholder(dtype=tf.float32, shape=(None, 10))
    33 # 定义损失函数,交叉熵损失函数
    34 # 对于多分类问题,通常使用交叉熵损失函数
    35 # reduction_indices等价于axis,指明按照每行加,还是按照每列加
    36 cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y),
    37                                               reduction_indices=[1]))
    38 train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
    39 
    40 # 评估
    41 
    42 # tf.argmax()是一个从tensor中寻找最大值的序号,tf.argmax就是求各个预测的数字中概率最大的那一个
    43 
    44 correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    45 
    46 # 用tf.cast将之前correct_prediction输出的bool值转换为float32,再求平均
    47 accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    48 
    49 # 初始化变量
    50 sess = tf.InteractiveSession()
    51 tf.global_variables_initializer().run()
    52 # 创建Saver节点,用于保存训练的模型
    53 saver = tf.train.Saver()
    54 for i in range(100):
    55     batch_xs, batch_ys = my_mnist.train.next_batch(100)
    56     sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
    57     # 每隔一段时间保存一次中间结果
    58     if i % 10 == 0:
    59         save_path = saver.save(sess, "C:/Users/zhen/MNIST_data_bak/saver/softmax_middle_model.ckpt")
    60     
    61     # print("TrainSet batch acc : %s " % accuracy.eval({x: batch_xs, y_: batch_ys}))
    62     # print("ValidSet acc : %s" % accuracy.eval({x: my_mnist.validation.images, y_: my_mnist.validation.labels}))
    63 
    64 # 测试
    65 print("TestSet acc : %s" % accuracy.eval({x: my_mnist.test.images, y_: my_mnist.test.labels}))
    66 # 保存最终的模型
    67 save_path = saver.save(sess, "C:/Users/zhen/MNIST_data_bak/saver/softmax_final_model.ckpt")
    68 
    69 # 使用训练好的模型直接进行预测
    70 with tf.Session() as sess_back:
    71     saver.restore(sess_back, "C:/Users/zhen/MNIST_data_bak/saver/softmax_final_model.ckpt")
    72     # 评估
    73     correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    74     accruary = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    75     # 测试
    76     print(accuracy.eval({x : my_mnist.test.images, y_ : my_mnist.test.labels}))
    77 # 总结
    78 # 1,定义算法公式,也就是神经网络forward时的计算
    79 # 2,定义loss,选定优化器,并指定优化器优化loss
    80 # 3,迭代地对数据进行训练
    81 # 4,在测试集或验证集上对准确率进行评测

    三.结果

        

      

    四.解析

      把训练好的模型存储落地磁盘,有利于多次使用和共享,也便于当训练出现异常时能恢复模型而不是重新训练!

  • 相关阅读:
    队列分类梳理
    停止线程
    Docker和Kubernetes
    Future、Callback、Promise
    Static、Final、static final
    线程池梳理
    TCP四次挥手
    http1.0、http1.x、http 2和https梳理
    重排序
    java内存模型梳理
  • 原文地址:https://www.cnblogs.com/yszd/p/9822365.html
Copyright © 2011-2022 走看看