zoukankan      html  css  js  c++  java
  • 【TensorFlow-windows】(三) 多层感知器进行手写数字识别(mnist)

    主要内容:
    1.基于多层感知器的mnist手写数字识别(代码注释)
    2.该实现中的函数总结

    平台:
    1.windows 10 64位
    2.Anaconda3-4.2.0-Windows-x86_64.exe (当时TF还不支持python3.6,又懒得在高版本的anaconda下配置多个Python环境,于是装了一个3-4.2.0(默认装python3.5),建议装anaconda3的最新版本,TF1.2.0版本已经支持python3.6!)
    3.TensorFlow1.1.0

    先贴代码:

    # -*- coding: utf-8 -*-
    """
    Created on Tue Jun 20 12:59:16 2017
    
    @author: ASUS
    """
    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    
    mnist = input_data.read_data_sets('MNIST_data/', one_hot = True) # 是一个tensorflow内部的变量
    sess = tf.InteractiveSession() # sess被注册为默认的session 
    
    #---------------第1/4步:定义算法公式-------------------
    # 各层参数初始化
    in_units = 784
    h1_units = 300
    W1 = tf.Variable(tf.truncated_normal([in_units, h1_units], stddev = 0.1)) # 截断正态分布,剔除2倍标准差之外的数据
    b1 = tf.Variable(tf.zeros([h1_units]))
    W2 = tf.Variable(tf.zeros([h1_units, 10]))
    b2 = tf.Variable(tf.zeros([10]))
    
    # 定义计算图的输入
    x = tf.placeholder(tf.float32, [None, in_units])
    keep_prob = tf.placeholder(tf.float32)  # Droput的比例
    
    # 定义隐藏层结构 h = relu(W1*x + b1)
    hidden1 = tf.nn.relu(tf.matmul(x, W1) + b1)
    hidden1_drop = tf.nn.dropout(hidden1, keep_prob)
    # 定义输出层
    y = tf.nn.softmax(tf.matmul(hidden1_drop, W2) + b2)
    # 定义计算图的输入,y_ 是输入真实标签
    y_ = tf.placeholder(tf.float32, [None, 10])
    
    #---------------第2/4步:定义loss和优化器-------------------
    
    cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y),
                                    reduction_indices = [1]))
    train_step = tf.train.AdagradOptimizer(0.3).minimize(cross_entropy)
    
    
    #---------------第3/4步:训练步骤-------------------
    tf.global_variables_initializer().run()
    #迭代地执行训练操作
    for i in range(100):
        batch_xs, batch_ys = mnist.train.next_batch(100) # batch 数为100
        train_step.run({x: batch_xs, y_: batch_ys, keep_prob: 0.75})
    
    #---------------第4/4步:模型评估-------------------
    
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print(accuracy)
    print(accuracy.eval({x: mnist.test.images, y_: mnist.test.labels,
                        keep_prob: 1.0}))
    

    整体分四步:
    1.定义算法公式
    2. 定义loss和优化器
    3. 训练
    4. 模型评估

    其中用到的函数总结(续上篇):
    1. sess = tf.InteractiveSession() 将sess注册为默认的session
    2. tf.placeholder() , Placeholder是输入数据的地方,也称为占位符,通俗的理解就是给输入数据(此例中的图片x)和真实标签(y_)提供一个入口,或者是存放地。(个人理解,可能不太正确,后期对TF有深入认识的话再回来改~~)
    3. tf.Variable() Variable是用来存储模型参数,与存储数据的tensor不同,tensor一旦使用掉就消失
    4. tf.matmul() 矩阵相乘函数
    5. tf.reduce_mean 和tf.reduce_sum 是缩减维度的计算均值,以及缩减维度的求和
    6. tf.argmax() 是寻找tensor中值最大的元素的序号 ,此例中用来判断类别
    7. tf.cast() 用于数据类型转换
    ————————————–我是分割线(一)———————————–

    tf.random_uniform 生成均匀分布的随机数
    tf.train.AdamOptimizer() 创建优化器,优化方法为Adam(adaptive moment estimation,Adam优化方法根据损失函数对每个参数的梯度的一阶矩估计和二阶矩估计动态调整针对于每个参数的学习速率)
    tf.placeholder “占位符”,只要是对网络的输入,都需要用这个函数这个进行“初始化”
    tf.random_normal 生成正态分布
    tf.add 和 tf.matmul 数据的相加 、相乘
    tf.reduce_sum 缩减维度的求和
    tf.pow 求幂函数
    tf.subtract 数据的相减
    tf.global_variables_initializer 定义全局参数初始化
    tf.Session 创建会话.
    tf.Variable 创建变量,是用来存储模型参数的变量。是有别于模型的输入数据的
    tf.train.AdamOptimizer (learning_rate = 0.001) 采用Adam进行优化,学习率为 0.001
    ————————————–我是分割线(二)———————————–
    1. hidden1_drop = tf.nn.dropout(hidden1, keep_prob) 给 hindden1层增加Droput,返回新的层hidden1_drop,keep_prob是 Droput的比例
    2. mnist.train.next_batch() 来详细讲讲 这个函数。一句话概括就是,打乱样本顺序,然后按顺序读取batch_size 个样本 进行返回。
    具体看代码及其注释,首先要找到函数定义,在tensorflowcontriblearnpythonlearndatasets 下的mnist.py

      def next_batch(self, batch_size, fake_data=False, shuffle=True):
        """Return the next `batch_size` examples from this data set.
            可以接收三个参数,第二个没搞明白,第三个就是打乱样本顺序啦
            默认为打乱,不想打乱可以加上  shuffle=False
        """
        # 第一个if 是对 fake_data的,默认是不用管啦
        if fake_data:
          fake_image = [1] * 784
          if self.one_hot:
            fake_label = [1] + [0] * 9
          else:
            fake_label = 0
          return [fake_image for _ in xrange(batch_size)], [
              fake_label for _ in xrange(batch_size)
          ]
        # self._index_in_epoch是被初始化为 0
        # start就是此 batch第一个样本所在整个数据集的编号
        start = self._index_in_epoch
        # Shuffle for the first epoch
        # 这个if是在第一次的时候才运行,因为打乱样本只需要进行一次
        if self._epochs_completed == 0 and start == 0 and shuffle:
          perm0 = numpy.arange(self._num_examples)
          numpy.random.shuffle(perm0)
          self._images = self.images[perm0]
          self._labels = self.labels[perm0]
        # Go to the next epoch 当最后一个batch小于
        # 这个if是针对 样本数不能整除batch_size而设定的
        if start + batch_size > self._num_examples:
          # Finished epoch
          self._epochs_completed += 1
          # Get the rest examples in this epoch
          rest_num_examples = self._num_examples - start
          images_rest_part = self._images[start:self._num_examples]
          labels_rest_part = self._labels[start:self._num_examples]
          # Shuffle the data
          if shuffle:
            perm = numpy.arange(self._num_examples)
            numpy.random.shuffle(perm)
            self._images = self.images[perm]
            self._labels = self.labels[perm]
          # Start next epoch
          start = 0
          self._index_in_epoch = batch_size - rest_num_examples
          end = self._index_in_epoch
          images_new_part = self._images[start:end]
          labels_new_part = self._labels[start:end]
          return numpy.concatenate((images_rest_part, images_new_part), axis=0) , numpy.concatenate((labels_rest_part, labels_new_part), axis=0)
        else:
        # 其实,正常情况下,只执行下面两条语句,以上代码体现了此代码的健壮性
        # self._index_in_epoch其实就是为下一个batch的start做准备的 
        # end 就是本batch的 最后一个样本所在整个数据集的编号
        # 本batch的第一个样本的编号(start) 在上面已经赋值了
          self._index_in_epoch += batch_size  
          end = self._index_in_epoch 
          return self._images[start:end], self._labels[start:end]
  • 相关阅读:
    java判断字符串是否为数字
    门萨高智商者的集中营
    Android全局变量是用public&nbsp…
    oracle 关闭查询的进程
    oracle 常用参考
    oracle创建临时表
    透明网关设置
    透明网关diy
    又一个下拉菜单导航按钮
    数据库备份或导出时丢失主键的相关知识
  • 原文地址:https://www.cnblogs.com/TensorSense/p/7413315.html
Copyright © 2011-2022 走看看