zoukankan      html  css  js  c++  java
  • day18-RNN实现手写数字识别

    
    # coding=utf-8
    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    from tensorflow.contrib import rnn
    
    def weight_variable(shape):
        """Create a trainable weight variable drawn from N(0, 1).

        :param shape: list/tuple tensor shape, e.g. [28, 128].
        :return: tf.Variable initialized from a normal distribution.
        """
        # BUG FIX: the original passed seed=0.0 — `seed` is an integer op
        # seed in tf.random_normal, so a float is invalid. Use an int seed
        # to keep the (apparently intended) deterministic initialization.
        weight = tf.Variable(tf.random_normal(shape, stddev=1.0, seed=0))
        return weight
    
    def bias_variable(shape):
        """Create a trainable bias variable drawn from N(0, 1).

        :param shape: list/tuple tensor shape, e.g. [128].
        :return: tf.Variable initialized from a normal distribution.
        """
        # BUG FIX: `seed` expects an int; the original's seed=0.0 (float)
        # is invalid. Keep stddev=1.0 and use an integer seed instead.
        bias = tf.Variable(tf.random_normal(shape, stddev=1.0, seed=0))
        return bias
    
    
    def model():
        """Build the LSTM graph for MNIST digit classification.

        :return: (x, y, predict) — the input placeholder [None, 784],
                 the one-hot label placeholder [None, 10], and the
                 logits tensor of shape [batch, 10].
        """
    
        # Input arrives as (batch, 784). Reshape to (batch, 28, 28), then
        # transpose to (28, batch, 28) so the 28 time steps (image rows)
        # come first, and finally flatten to (28*batch, 28) so the hidden
        # projection below can be done with a single matmul.
        x = tf.placeholder(tf.float32,[None,784])
        y = tf.placeholder(tf.float32,[None,10])
        x_reshape = tf.reshape(x,[-1,28,28])
        x_reshape = tf.transpose(x_reshape,[1,0,2])
        x_reshape = tf.reshape(x_reshape,[-1,28])
    
        # Hidden-layer projection parameters: 28 input features -> 128 units.
        weight = weight_variable([28,128])
        bias = bias_variable([128])
        # Output-layer parameters: 128 LSTM units -> 10 digit classes.
        weight_final = weight_variable([128, 10])
        bias_final = bias_variable([10])
    
        # Project all time steps at once: (28*batch, 28) -> (28*batch, 128).
        h = tf.matmul(x_reshape,weight) + bias
        # Split back into 28 per-time-step tensors of shape (batch, 128),
        # the list-of-steps format static_rnn expects.
        h_split = tf.split(h,28,0)
    
        # The LSTM cell — the recurrent core of the network.
        lstm_cell = rnn.BasicLSTMCell(128)
        # For tensorflow < 1.0 use instead:
        # lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(128, forget_bias=1.0)
        lstm_o, lstm_s = rnn.static_rnn(lstm_cell, h_split, dtype=tf.float32)
    
        # Classify from the LSTM output at the last time step only.
        predict = tf.matmul(lstm_o[-1], weight_final) + bias_final
    
        return x,y,predict
    
    def RnnCon():
        """Train the LSTM MNIST classifier for 2000 mini-batch steps.

        Reads MNIST from ../data/day06/, writes loss/accuracy summaries to
        ../summary/day08/, and prints the batch accuracy after every step.
        :return: None
        """

        # 1. Load the data set (labels one-hot encoded).
        mnist = input_data.read_data_sets("../data/day06/",one_hot=True)

        # 2. Build the graph.
        x, y_true, y_predict = model()

        # 3. Loss: softmax cross-entropy averaged over the batch.
        # NOTE(review): softmax_cross_entropy_with_logits is deprecated in
        # TF >= 1.5; the _v2 variant is a drop-in replacement here since the
        # labels are placeholders (no gradient flows into them anyway).
        with tf.variable_scope("model_soft_corss"):
            softmax = tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=y_predict)
            loss = tf.reduce_mean(softmax)

        # 4. Optimizer: plain SGD.  (Fixed the original `tarin_op` typo.)
        with tf.variable_scope("model_better"):
            train_op = tf.train.GradientDescentOptimizer(0.001).minimize(loss)

        # 5. Accuracy: fraction of samples whose predicted argmax matches
        # the one-hot label's argmax.
        with tf.variable_scope("model_acc"):
            equal_list = tf.equal(tf.argmax(y_true, 1), tf.argmax(y_predict, 1))
            acc = tf.reduce_mean(tf.cast(equal_list, tf.float32))

        # 6. Variable-init op and scalar summaries for TensorBoard.
        init_op = tf.global_variables_initializer()
        tf.summary.scalar("losses", loss)
        tf.summary.scalar("acces", acc)
        merge = tf.summary.merge_all()

        with tf.Session() as sess:
            sess.run(init_op)

            filewriter = tf.summary.FileWriter("../summary/day08/", graph=sess.graph)
            try:
                for i in range(2000):
                    mnist_x, mnist_y = mnist.train.next_batch(16)
                    feed = {x: mnist_x, y_true: mnist_y}

                    # Apply one gradient step.
                    sess.run(train_op, feed_dict=feed)

                    # PERF: the original ran two extra separate sess.run calls
                    # (one for acc, one for merge) — i.e. two more forward
                    # passes per step.  Fetch both in a single run; this also
                    # keeps the original post-update accuracy semantics.
                    acc_val, summary = sess.run([acc, merge], feed_dict=feed)

                    print("第%d次训练,准确率为:%f" % ((i + 1), acc_val))
                    filewriter.add_summary(summary, i)
            finally:
                # BUG FIX: the writer was never closed; flush events to disk.
                filewriter.close()

        return None
    
    
    
    
    
    if __name__ == '__main__':
        RnnCon()
    
    
  • 相关阅读:
    Item02.多态 Polymorphism
    使用Singleton需要考虑内存释放
    Item08. 多级指针(Pointers to Pointers)
    Item01: 数据提取(Data Abstraction)
    Item 05. 引用(References Are Aliases, Not Pointers)
    华为3Com Quidway 2116SI
    DLink DES1226G 一款不错的中端交换机
    郁闷
    一些VLAN学习资料
    有个好心情才会有好的状态
  • 原文地址:https://www.cnblogs.com/wuren-best/p/14366310.html
Copyright © 2011-2022 走看看