zoukankan      html  css  js  c++  java
  • 神经网络 5：循环神经网络（一）——用 LSTM 对 MNIST 手写数字分类

    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    
    # Load MNIST from the "MNIST_data" directory; one_hot=True makes every
    # label a length-10 indicator vector (matching n_classes below).
    mnist = input_data.read_data_sets("MNIST_data", one_hot=True)

    # --- optimisation settings ---
    learn_rate = 0.001
    train_iters = 100000  # training stops once step * batch_size reaches this
    batch_size = 128

    # --- model geometry ---
    # The training loop reshapes each flattened image to (n_steps, n_inputs),
    # so the LSTM consumes one 28-value row per time step.
    n_inputs = 28
    n_steps = 28
    n_hidden_units = 128
    n_classes = 10

    # Batch dimension left as None so any number of examples can be fed.
    in_x = tf.placeholder(tf.float32, shape=[None, n_steps, n_inputs])
    in_label = tf.placeholder(tf.float32, shape=[None, n_classes])

    # 'in' projects each input row into the LSTM width; 'out' maps the final
    # hidden state to class logits.
    weights = {
        'in': tf.Variable(tf.random_normal(shape=[n_inputs, n_hidden_units])),
        'out': tf.Variable(tf.random_normal(shape=[n_hidden_units, n_classes]))
    }
    # Biases start at a small positive constant (0.1).
    biases = {
        'in': tf.Variable(tf.constant(0.1, shape=[n_hidden_units])),
        'out': tf.Variable(tf.constant(0.1, shape=[n_classes]))
    }
    
    
    def RNN(inputs, weights, biases):
        """Run a single-layer LSTM over a batch of sequences and return class logits.

        Args:
            inputs: float32 tensor fed as ``[batch, n_steps, n_inputs]`` (the
                shape of the ``in_x`` placeholder this is called with).
            weights: dict with ``'in'`` ([n_inputs, n_hidden_units]) and
                ``'out'`` ([n_hidden_units, n_classes]) weight matrices.
            biases: dict with the matching ``'in'`` and ``'out'`` bias vectors.

        Returns:
            Unnormalised logits of shape ``[batch, n_classes]``.
        """
        # Take the batch size from the fed tensor rather than the global
        # ``batch_size`` constant. Otherwise the initial LSTM state is fixed at
        # 128 rows and the graph fails for any other number of examples
        # (e.g. evaluating the whole test set).
        dynamic_batch = tf.shape(inputs)[0]

        # Project every time step from n_inputs to n_hidden_units features:
        # flatten (batch, steps) so a single matmul covers all steps, then
        # restore the [batch, n_steps, n_hidden_units] layout.
        inputs = tf.reshape(inputs, [-1, n_inputs])
        inputs_in = tf.matmul(inputs, weights['in']) + biases['in']
        inputs_in = tf.reshape(inputs_in, [-1, n_steps, n_hidden_units])

        # cell
        lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=n_hidden_units, forget_bias=1.0, state_is_tuple=True)
        _init_state = lstm_cell.zero_state(batch_size=dynamic_batch, dtype=tf.float32)
        outputs, states = tf.nn.dynamic_rnn(lstm_cell, inputs_in, initial_state=_init_state, time_major=False)
        # With state_is_tuple=True the final state is an LSTMStateTuple(c, h).
        # Classify from the hidden state h after the last step (== states[1]).
        results = tf.matmul(states.h, weights['out']) + biases['out']
        return results
    
    
    # Model output, loss and optimiser step.
    pred = RNN(in_x, weights, biases)
    cost = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=in_label, logits=pred)
    )
    train = tf.train.AdamOptimizer(learn_rate).minimize(cost)

    # Accuracy: fraction of examples whose highest-scoring logit matches the
    # position of the 1 in the one-hot label.
    predicted_class = tf.argmax(pred, 1)
    true_class = tf.argmax(in_label, 1)
    correct_pred = tf.equal(predicted_class, true_class)
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    init = tf.global_variables_initializer()

    with tf.Session() as sess:
        sess.run(init)
        # Keep stepping while step * batch_size < train_iters,
        # i.e. ceil(train_iters / batch_size) mini-batches in total.
        n_batches = (train_iters + batch_size - 1) // batch_size
        for step in range(n_batches):
            images, labels = mnist.train.next_batch(batch_size)
            # Present each flattened image as n_steps rows of n_inputs values.
            images = images.reshape([batch_size, n_steps, n_inputs])
            feed = {in_x: images, in_label: labels}
            sess.run(train, feed_dict=feed)
            # Every 20 steps, report accuracy on the batch just trained on
            # (training-batch accuracy, not held-out accuracy).
            if not step % 20:
                print(sess.run(accuracy, feed_dict=feed))
  • 相关阅读:
    基于用例的工作量估计
    在xmlhttp中传递cookie和表单数据
    swfupload 上传SecurityError Error #2156 拓荒者
    为PetaPoco添加Fill方法 拓荒者
    【转】全面理解javascript的caller,callee,call,apply概念 拓荒者
    Dojo入门:dojo中的事件处理 拓荒者
    mstsc远程连接时的全屏快捷键 拓荒者
    关于DateTime对象序列化为Json之后的若干问题 拓荒者
    PetaPoco的默认映射 拓荒者
    DataTable序列化为JSON字符串 拓荒者
  • 原文地址:https://www.cnblogs.com/infoo/p/9509017.html
Copyright © 2011-2022 走看看