  • TensorFlow Study Notes 7: RNN

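    Unlike an ordinary feed-forward neural network, an RNN treats successive pieces of input data as related to one another: what it computes for one piece depends on the pieces that came before it. The input data therefore has to be arranged as a sequence.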

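    1. Arrange the data as a sequence;

    2. Pass the first element through the hidden layer and then into the RNN cell. The cell produces two results: one is used for the classification output, and the other is passed on to the RNN cell that handles the second element;

    3. Repeat until every element has been processed.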
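
    The three steps above can be sketched in plain Python. This is only a minimal illustration with a scalar tanh cell, not the LSTM used later in these notes; the names rnn_step and run_rnn and the weight values are made up for the example.

```python
import math

def rnn_step(x_t, h_prev, w_x=0.5, w_h=0.8, b=0.0):
    """One recurrent step: mix the current input with the previous state."""
    h_t = math.tanh(w_x * x_t + w_h * h_prev + b)
    # A plain tanh cell returns the same value twice: once as this step's
    # output and once as the state handed to the next step.
    return h_t, h_t

def run_rnn(sequence, h0=0.0):
    """Feed the sequence element by element, carrying the state forward."""
    outputs = []
    state = h0
    for x_t in sequence:
        out, state = rnn_step(x_t, state)
        outputs.append(out)
    return outputs, state

outputs, final_state = run_rnn([1.0, 0.0, 0.0])
print(outputs)
```

    Only the first input is non-zero, yet the later outputs are non-zero too, because each step receives the state left behind by the step before it.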

    Importing the data

    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    import numpy as np
    import matplotlib.pyplot as plt
    print ("Packages imported")
    
    mnist = input_data.read_data_sets("data/", one_hot=True)
    trainimgs, trainlabels, testimgs, testlabels \
     = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
    ntrain, ntest, dim, nclasses \
     = trainimgs.shape[0], testimgs.shape[0], trainimgs.shape[1], trainlabels.shape[1]
    print ("MNIST loaded")

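    Treat each 28*28-pixel image as a sequence of 28 rows with 28 values each; the hidden layer has 128 neurons; then define the weights and biases.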

    diminput  = 28
    dimhidden = 128
    dimoutput = nclasses
    nsteps    = 28
    weights = {
        'hidden': tf.Variable(tf.random_normal([diminput, dimhidden])), 
        'out': tf.Variable(tf.random_normal([dimhidden, dimoutput]))
    }
    biases = {
        'hidden': tf.Variable(tf.random_normal([dimhidden])),
        'out': tf.Variable(tf.random_normal([dimoutput]))
    }
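
    To make the "28 rows" idea concrete, here is a small pure-Python sketch that cuts one flattened 784-value image into nsteps rows of diminput values each. The pixel values are placeholders (just the indices 0..783), not real MNIST data.

```python
diminput = 28   # values per time step (one image row)
nsteps = 28     # time steps per image (number of rows)

# Placeholder for one flattened MNIST image: 28 * 28 = 784 values
flat_image = list(range(diminput * nsteps))

# Row t becomes the input of time step t
rows = [flat_image[t * diminput:(t + 1) * diminput] for t in range(nsteps)]

print(len(rows), len(rows[0]))
```

    This is the same regrouping that the reshape to (batch_size, nsteps, diminput) performs for a whole batch in the training loop.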

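    Define the RNN function. It rearranges the input, computes the hidden layer, splits the hidden-layer output into one slice per time step, and runs the LSTM, which returns two results (the outputs and the state). The prediction is computed from the LSTM output of the last time step, _LSTM_O[-1].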

    def _RNN(_X, _W, _b, _nsteps, _name):
        # 1. Permute input from [batchsize, nsteps, diminput] 
        #   => [nsteps, batchsize, diminput]
        _X = tf.transpose(_X, [1, 0, 2])
        # 2. Reshape input to [nsteps*batchsize, diminput] 
        _X = tf.reshape(_X, [-1, diminput])
        # 3. Input layer => Hidden layer
        _H = tf.matmul(_X, _W['hidden']) + _b['hidden']
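        # 4. Split data into 'nsteps' chunks. The i-th chunk holds the i-th time step of every sample
        #    (TensorFlow 1.x argument order; releases before 1.0 used tf.split(0, _nsteps, _H))
        _Hsplit = tf.split(_H, _nsteps, 0)
        # 5. Get the LSTM's outputs (_LSTM_O, one per time step) and final state (_LSTM_S)
        #    Each output in _LSTM_O has 'batchsize' rows
        #    Only the last output, _LSTM_O[-1], is used to predict the class.
        with tf.variable_scope(_name):
            # No scope.reuse_variables() here: the LSTM variables are created on this
            # first call, and asking for reuse before they exist raises a ValueError.
            lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(dimhidden, forget_bias=1.0)
            # tf.nn.static_rnn is the TensorFlow 1.x name of the former tf.nn.rnn
            _LSTM_O, _LSTM_S = tf.nn.static_rnn(lstm_cell, _Hsplit, dtype=tf.float32)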
        # 6. Output
        _O = tf.matmul(_LSTM_O[-1], _W['out']) + _b['out']    
        # Return! 
        return {
            'X': _X, 'H': _H, 'Hsplit': _Hsplit,
            'LSTM_O': _LSTM_O, 'LSTM_S': _LSTM_S, 'O': _O 
        }
    print ("Network ready")
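
    The transpose, reshape, and split in steps 1, 2, and 4 of _RNN are easier to follow on a tiny example. The sketch below mimics them with nested Python lists instead of tensors (the hidden-layer matmul of step 3 is skipped, since it does not change the row order); the sizes and the string labels are made up for illustration.

```python
batch, nsteps, dim = 2, 3, 2

# x[b][t] is the feature vector of sample b at time step t,
# i.e. the layout [batchsize, nsteps, diminput]
x = [[[f"b{b}t{t}d{d}" for d in range(dim)]
      for t in range(nsteps)]
     for b in range(batch)]

# 1. Permute to [nsteps, batchsize, diminput]
xt = [[x[b][t] for b in range(batch)] for t in range(nsteps)]

# 2. Flatten to [nsteps*batchsize, diminput]
flat = [row for step in xt for row in step]

# 4. Cut into nsteps chunks of batchsize rows each
chunks = [flat[t * batch:(t + 1) * batch] for t in range(nsteps)]

for t, chunk in enumerate(chunks):
    print(t, chunk)
```

    Each chunk ends up holding one time step for every sample in the batch, which is the list-of-steps form that the RNN call consumes.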

    With the RNN defined, set up the loss function, the optimizer, and the accuracy measure.

    learning_rate = 0.001
    x      = tf.placeholder("float", [None, nsteps, diminput])
    y      = tf.placeholder("float", [None, dimoutput])
    myrnn  = _RNN(x, weights, biases, nsteps, 'basic')
    pred   = myrnn['O']
    cost   = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=pred))
    optm   = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost) # plain gradient descent, not Adam
    accr   = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(pred,1), tf.argmax(y,1)), tf.float32))
    init   = tf.global_variables_initializer()
    print ("Network Ready!")
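
    As a rough guide to what cost and accr compute, the following pure-Python sketch works through softmax cross-entropy for one sample and accuracy for a small batch. The helper names softmax, cross_entropy, and accuracy are hypothetical and written only for this illustration; they are not TensorFlow functions.

```python
import math

def softmax(logits):
    """Turn raw scores into probabilities that sum to 1."""
    m = max(logits)                      # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(logits, one_hot):
    """Negative log-probability assigned to the true class."""
    probs = softmax(logits)
    return -sum(t * math.log(p) for t, p in zip(one_hot, probs) if t > 0)

def accuracy(batch_logits, batch_labels):
    """Fraction of samples whose arg-max prediction matches the label."""
    hits = [l.index(max(l)) == y.index(max(y))
            for l, y in zip(batch_logits, batch_labels)]
    return sum(hits) / len(hits)

print(cross_entropy([2.0, 1.0, 0.1], [1, 0, 0]))
print(accuracy([[2, 1], [0, 3]], [[1, 0], [1, 0]]))
```

    In the graph above, the per-sample losses are then averaged over the batch with tf.reduce_mean.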

    Training

    training_epochs = 5
    batch_size      = 16
    display_step    = 1
    sess = tf.Session()
    sess.run(init)
    print ("Start optimization")
    for epoch in range(training_epochs):
        avg_cost = 0.
        total_batch = int(mnist.train.num_examples/batch_size)
     
        # Loop over all batches
        for i in range(total_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            batch_xs = batch_xs.reshape((batch_size, nsteps, diminput))
            # Fit training using batch data
            feeds = {x: batch_xs, y: batch_ys}
            sess.run(optm, feed_dict=feeds)
            # Compute average loss
            avg_cost += sess.run(cost, feed_dict=feeds)/total_batch
        # Display logs per epoch step
        if epoch % display_step == 0: 
            print ("Epoch: %03d/%03d cost: %.9f" % (epoch, training_epochs, avg_cost))
            feeds = {x: batch_xs, y: batch_ys}
            train_acc = sess.run(accr, feed_dict=feeds)
            print (" Training accuracy: %.3f" % (train_acc))
            testimgs = testimgs.reshape((ntest, nsteps, diminput))
            feeds = {x: testimgs, y: testlabels}
            test_acc = sess.run(accr, feed_dict=feeds)
            print (" Test accuracy: %.3f" % (test_acc))
    print ("Optimization Finished.")
  • Original article: https://www.cnblogs.com/xxp17457741/p/9483514.html