zoukankan      html  css  js  c++  java
  • LSTM

    MNIST的循环神经网络--LSTM

    首先加载数据
    然后构建模型

    首先设置训练的超参数:学习率、训练轮数(epoch)和每个批次的数据量(batch size)
    定义输入数据及权重
    定义模型
    训练和评估模型

    #Inspired by https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/3%20-%20Neural%20Networks/recurrent_network.py
    import tensorflow as tf
    from tensorflow.contrib import rnn
    
    import numpy as np
    from tensorflow.examples.tutorials.mnist import input_data
    
    # configuration
    #                        O * W + b -> 10 labels for each image, O[? 28], W[28 10], B[10]
    #                       ^ (O: output 28 vec from 28 vec input)
    #                       |
    #      +-+  +-+       +--+
    #      |1|->|2|-> ... |28| time_step_size = 28
    #      +-+  +-+       +--+
    #       ^    ^    ...  ^
    #       |    |         |
    # img1:[28] [28]  ... [28]
    # img2:[28] [28]  ... [28]
    # img3:[28] [28]  ... [28]
    # ...
    # img128 or img256 (batch_size or test_size 256)
    #      each input size = input_vec_size=lstm_size=28
    
    # Model dimensions: each image is fed to the LSTM as time_step_size
    # consecutive vectors of input_vec_size values, and the LSTM hidden
    # size is set equal to the input vector size.
    time_step_size = 28
    input_vec_size = 28
    lstm_size = input_vec_size

    # Number of samples per training step and per evaluation sample.
    test_size = 256
    batch_size = 128
    
    def init_weights(shape):
        """Return a tf.Variable of the given shape with small random normal values.

        The initial values are drawn with a standard deviation of 0.01.
        """
        initial_values = tf.random_normal(shape, stddev=0.01)
        return tf.Variable(initial_values)
    
    
    def model(X, W, B, lstm_size):
        """Build an LSTM classifier graph that reads each sample step by step.

        Args:
            X: input tensor of shape (batch_size, time_step_size, input_vec_size).
                The time axis must have exactly the module-level
                ``time_step_size`` entries.
            W: output projection weights; must be multipliable with the
                (batch_size, lstm_size) output of the last LSTM step.
            B: bias added to the projected output.
            lstm_size: number of hidden units in the LSTM cell.

        Returns:
            A tuple ``(logits, state_size)`` where ``logits`` is the linear
            projection of the LSTM output at the last time step and
            ``state_size`` is the cell's state size, usable to initialize
            the state.
        """
        # Split along the time axis into time_step_size tensors, each of shape
        # (batch_size, input_vec_size). Unstacking does not need to know the
        # input width, so the hidden size (lstm_size) no longer has to equal
        # the input vector size as it did with the former reshape.
        X_split = tf.unstack(X, num=time_step_size, axis=1)

        # Make lstm with lstm_size hidden units
        lstm = rnn.BasicLSTMCell(lstm_size, forget_bias=1.0, state_is_tuple=True)

        # Get lstm cell outputs: time_step_size tensors of shape (batch_size, lstm_size)
        outputs, _states = rnn.static_rnn(lstm, X_split, dtype=tf.float32)

        # Linear activation on the last output only
        logits = tf.matmul(outputs[-1], W) + B
        return logits, lstm.state_size  # state size to initialize the state
    
    # Training hyperparameters
    learning_rate = 0.001
    rmsprop_decay = 0.9
    num_epochs = 100
    num_classes = 10

    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
    trX, trY, teX, teY = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
    # Reshape every image to (time_step_size, input_vec_size) so that the model
    # can consume it as a sequence of time_step_size vectors.
    trX = trX.reshape(-1, time_step_size, input_vec_size)
    teX = teX.reshape(-1, time_step_size, input_vec_size)

    X = tf.placeholder("float", [None, time_step_size, input_vec_size])
    Y = tf.placeholder("float", [None, num_classes])

    # Output projection: lstm_size features -> num_classes labels
    W = init_weights([lstm_size, num_classes])
    B = init_weights([num_classes])

    py_x, state_size = model(X, W, B, lstm_size)

    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=py_x, labels=Y))
    train_op = tf.train.RMSPropOptimizer(learning_rate, decay=rmsprop_decay).minimize(cost)
    predict_op = tf.argmax(py_x, 1)

    session_conf = tf.ConfigProto()
    session_conf.gpu_options.allow_growth = True  # allocate GPU memory on demand

    # Launch the graph in a session
    with tf.Session(config=session_conf) as sess:
        # you need to initialize all variables
        tf.global_variables_initializer().run()

        for i in range(num_epochs):
            # Only full batches are used; a trailing partial batch is skipped.
            for start in range(0, len(trX) - batch_size + 1, batch_size):
                end = start + batch_size
                sess.run(train_op, feed_dict={X: trX[start:end], Y: trY[start:end]})

            # Evaluate on a random sample of test_size test images
            test_indices = np.arange(len(teX))
            np.random.shuffle(test_indices)
            test_indices = test_indices[0:test_size]

            expected = np.argmax(teY[test_indices], axis=1)
            predicted = sess.run(predict_op, feed_dict={X: teX[test_indices]})
            print(i, np.mean(expected == predicted))
    

    0 0.69140625
    1 0.81640625
    2 0.88671875
    3 0.921875
    4 0.91015625
    5 0.953125
    6 0.9453125
    7 0.95703125
    8 0.96484375
    9 0.953125
    10 0.9765625
    11 0.9609375
    12 0.9609375
    13 0.93359375
    14 0.97265625
    15 0.984375
    16 0.98828125
    17 0.97265625
    18 0.9765625

  • 相关阅读:
    Mvc+三层(批量添加、删除、修改)
    js中判断复选款是否选中
    EF的优缺点
    Git tricks: Unstaging files
    Using Git Submodules
    English Learning
    wix xslt for adding node
    The breakpoint will not currently be hit. No symbols have been loaded for this document."
    Use XSLT in wix
    mfc110ud.dll not found
  • 原文地址:https://www.cnblogs.com/Ann21/p/10484547.html
Copyright © 2011-2022 走看看