zoukankan      html  css  js  c++  java
  • 深度学习06

    循环神经网络（LSTM）：把每张 28×28 的 MNIST 图像按行看作 28 个时间步的序列，用于手写数字分类

     

    代码

    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    
    
    # Graph mode is required: the model below is built from placeholders and run in a Session.
    tf.compat.v1.disable_eager_execution()

    # MNIST read from ./mnist_data, with labels one-hot encoded.
    mnist = input_data.read_data_sets("./mnist_data", one_hot=True)

    trainimgs, trainlabels = mnist.train.images, mnist.train.labels
    testimgs, testlabels = mnist.test.images, mnist.test.labels

    # Sample counts, flattened feature length and number of classes.
    ntrain = trainimgs.shape[0]
    ntest = testimgs.shape[0]
    dim = trainimgs.shape[1]
    nclasses = trainlabels.shape[1]

    # Each sample is consumed as a sequence of `nsteps` steps of `diminput` values.
    nsteps = 28
    diminput = 28
    dimhidden = 128          # LSTM state size / width of the input projection
    dimoutput = nclasses


    def _normal_var(shape):
        """Create a trainable variable initialised by tf.compat.v1.random_normal."""
        return tf.Variable(tf.compat.v1.random_normal(shape))


    # Input projection (diminput -> dimhidden) and output layer (dimhidden -> classes).
    weights = {
        'hidden': _normal_var([diminput, dimhidden]),
        'out': _normal_var([dimhidden, dimoutput]),
    }
    biases = {
        'hidden': _normal_var([dimhidden]),
        'out': _normal_var([dimoutput]),
    }
    
    def _RNN(_X, _W, _b, _nsteps, _name):
        """Run a single-layer LSTM over a batch of sequences and return its intermediate tensors.

        Args:
            _X: input tensor, assumed to be (batch, _nsteps, diminput); the
                transpose/reshape below rely on that layout.
            _W: dict with 'hidden' (diminput x dimhidden) input-projection
                matrix and 'out' (dimhidden x number-of-classes) output matrix.
            _b: dict with the matching 'hidden' and 'out' bias vectors.
            _nsteps: number of time steps to unroll the LSTM over.
            _name: variable-scope name under which the LSTM cell's variables live.

        Returns:
            dict with keys 'X' (time-major input flattened to 2-D), 'H'
            (projected input), 'Hsplit' (list of _nsteps per-step tensors),
            'LSTM_O' (per-step LSTM outputs), 'LSTM_S' (LSTM state returned
            by static_rnn) and 'O' (logits computed from the last step's output).
        """
        # Take the input and hidden widths from the projection matrix itself
        # instead of module globals, so any compatible _W works.
        diminput = int(_W['hidden'].shape[0])
        dimhidden = int(_W['hidden'].shape[1])

        # [batch, nsteps, diminput] -> [nsteps, batch, diminput] (time-major)
        _X = tf.transpose(_X, [1, 0, 2])
        # -> [nsteps * batch, diminput] so a single matmul projects every step
        _X = tf.reshape(_X, [-1, diminput])
        _H = tf.matmul(_X, _W['hidden']) + _b['hidden']
        # static_rnn expects a Python list with one tensor per time step
        _Hsplit = tf.split(_H, _nsteps, 0)

        print(_name)
        # AUTO_REUSE already provides variable sharing: variables are created on
        # the first call and reused afterwards. Calling scope.reuse_variables()
        # here forced reuse=True before the cell's variables existed, so the
        # first graph build could not create them.
        with tf.compat.v1.variable_scope(_name, reuse=tf.compat.v1.AUTO_REUSE):
            lstm_cell = tf.compat.v1.nn.rnn_cell.BasicLSTMCell(dimhidden, forget_bias=1.0)
            _LSTM_O, _LSTM_S = tf.compat.v1.nn.static_rnn(lstm_cell, _Hsplit, dtype=tf.float32)
        # Classify from the output of the final time step only.
        _O = tf.matmul(_LSTM_O[-1], _W['out']) + _b['out']

        return {
            'X': _X,
            'H': _H,
            'Hsplit': _Hsplit,
            'LSTM_O': _LSTM_O,
            'LSTM_S': _LSTM_S,
            'O': _O,
        }
    
    
    learning_rate = 0.001
    # Input: each sample as nsteps time steps of diminput values.
    x = tf.compat.v1.placeholder(tf.float32, [None, nsteps, diminput])
    # Target: one-hot class labels.
    y = tf.compat.v1.placeholder(tf.float32, [None, dimoutput])
    myrnn = _RNN(x, weights, biases, nsteps, 'basic')
    pred = myrnn['O']  # unnormalized logits
    # Pass labels/logits by keyword. In TF2 the positional order is
    # (labels, logits), so the original call (pred, y) treated the predictions
    # as labels; TF1 rejects positional arguments here entirely.
    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=pred))
    # NOTE(review): the original comment just said "Adam" — presumably
    # AdamOptimizer was considered as an alternative to plain SGD.
    optm = tf.compat.v1.train.GradientDescentOptimizer(learning_rate).minimize(cost)
    # Fraction of samples whose arg-max prediction matches the one-hot label.
    accr = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1)), tf.float32))
    init = tf.compat.v1.global_variables_initializer()
    
    
    with tf.compat.v1.Session() as sess:
        # Initialize all variables
        sess.run(init)

        # One constant shared by next_batch and the reshape so they cannot drift apart.
        batch_size = 16

        # Start training
        for i in range(100):
            # Fetch a mini-batch of flattened images and their one-hot labels
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            # Flattened images -> [batch, nsteps, diminput] to match placeholder x
            batch_xs = batch_xs.reshape((batch_size, nsteps, diminput))

            # One optimizer step; loss and accuracy are measured on this training batch
            _, loss_value, accuracy_value = sess.run([optm, cost, accr], feed_dict={x: batch_xs, y: batch_ys})
            print("第%d次的损失为%f,准确率为%f" % (i + 1, loss_value, accuracy_value))
  • 相关阅读:
    python小记(4)
    python小记(3)
    python小记(2)
    python小记(1)
    Linux学习
    plist文件的 偏好设置 存储与读取 自定义对象归档
    控制器创建的三种方式
    IOS应用启动过程
    pch文件中自定义log
    textLabel辅助试图及toolBar创建使用
  • 原文地址:https://www.cnblogs.com/MoooJL/p/14354549.html
Copyright © 2011-2022 走看看