  • Study Progress Notes 18

    Watched lesson 18 of the TensorFlow hands-on case-study video course: training an RNN network.
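
    The code below continues from the earlier lessons, which already defined the MNIST data, the dimensions and the weight/bias variables. A minimal sketch of that setup is given here for reference; the exact values and the data path are my assumption, based on feeding each 28x28 MNIST image row by row as a 28-step sequence.

    import tensorflow as tf
    import numpy as np
    from tensorflow.examples.tutorials.mnist import input_data

    #Assumed setup from earlier lessons: MNIST loaded one-hot, images fed row by row
    mnist=input_data.read_data_sets("data/",one_hot=True)
    testimgs,testlabels=mnist.test.images,mnist.test.labels
    ntest=testimgs.shape[0]

    diminput=28    #features per step (one image row)
    dimhidden=128  #LSTM hidden units
    dimoutput=10   #number of digit classes
    nsteps=28      #sequence length (number of rows)

    weights={
        'hidden':tf.Variable(tf.random_normal([diminput,dimhidden])),
        'out':tf.Variable(tf.random_normal([dimhidden,dimoutput]))
    }
    biases={
        'hidden':tf.Variable(tf.random_normal([dimhidden])),
        'out':tf.Variable(tf.random_normal([dimoutput]))
    }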

    def _RNN(_X,_W,_b,_nsteps,_name):
        #1.Permute input from [batchsize,nsteps,diminput]
        #  =>[nsteps,batchsize,diminput]
        _X=tf.transpose(_X,[1,0,2])
        #2.Reshape input to [nsteps*batchsize,diminput]
        _X=tf.reshape(_X,[-1,diminput])
        #3.Input layer => Hidden layer
        _H=tf.matmul(_X,_W['hidden'])+_b['hidden']
        #4.Split data into 'nsteps' chunks; the i-th chunk holds the i-th time step of every sample in the batch
        _Hsplit=tf.split(_H,_nsteps,0)
        #5.Get the LSTM outputs (_LSTM_O) and final state (_LSTM_S)
        #  _LSTM_O is a list of 'nsteps' tensors, one per time step;
        #  only the last output is used to predict the class.
        with tf.variable_scope(_name,reuse=tf.AUTO_REUSE):
            lstm_cell=tf.nn.rnn_cell.BasicLSTMCell(dimhidden,forget_bias=1.0)
            _LSTM_O,_LSTM_S=tf.nn.static_rnn(lstm_cell,_Hsplit,dtype=tf.float32)
        #6.Output layer applied to the last time step
        _O=tf.matmul(_LSTM_O[-1],_W['out'])+_b['out']
        #Return everything so intermediate tensors can be inspected
        return {
            'X':_X,'H':_H,'Hsplit':_Hsplit,
            'LSTM_O':_LSTM_O,'LSTM_S':_LSTM_S,'O':_O
        }
    print("Network ready")
    
    learning_rate=0.001
    x=tf.placeholder("float",[None,nsteps,diminput])
    y=tf.placeholder("float",[None,dimoutput])
    myrnn=_RNN(x,weights,biases,nsteps,'basic')
    #myrnn=_RNN(x,weights,biases,nsteps,'basic1')
    pred=myrnn['O']
    cost=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred,labels=y))
    optm=tf.train.GradientDescentOptimizer(learning_rate).minimize(cost) #plain SGD; tf.train.AdamOptimizer(learning_rate) also works
    accr=tf.reduce_mean(tf.cast(tf.equal(tf.argmax(pred,1),tf.argmax(y,1)),tf.float32))
    init=tf.global_variables_initializer()
    print("Network Ready!")
    
    training_epochs=5
    batch_size=16
    display_step=1
    sess=tf.Session()
    sess.run(init)
    for epoch in range(training_epochs):
        avg_cost=0
        #total_batch=int(mnist.train.num_examples/batch_size)
        total_batch=100
        #Loop over all batches
        for i in range(total_batch):
            batch_xs,batch_ys=mnist.train.next_batch(batch_size)
            batch_xs=batch_xs.reshape((batch_size,nsteps,diminput))
            #Fit training using batch data
            feeds={x:batch_xs,y:batch_ys}
            sess.run(optm,feed_dict=feeds)
            #Compute average loss
            avg_cost+=sess.run(cost,feed_dict=feeds)/total_batch
        #Display logs per epoch step
        if epoch % display_step==0:
            print("Epoch:%03d/%03d cost:%.9f" % (epoch,training_epochs,avg_cost))
            feeds={x:batch_xs,y:batch_ys}
            train_acc=sess.run(accr,feed_dict=feeds)
            print("Training accuracy:%.3f" % (train_acc))
            testimgs=testimgs.reshape((ntest,nsteps,diminput))
            feeds={x:testimgs,y:testlabels} #this _RNN has no initial-state placeholder, so only x and y are fed
            test_acc=sess.run(accr,feed_dict=feeds)
            print("Test accuracy:%.3f" % (test_acc))
    print("Optimization Finished.")
  • Original post: https://www.cnblogs.com/zql-42/p/14631137.html