zoukankan      html  css  js  c++  java
  • 【6-1】RNN循环神经网络

    一、问题:如何利用神经网络处理序列问题(语音、文本)?

    在MNIST手写数字识别中,输入一张图片,得到一个结果,输入另一张图片,得到另一个结果,输入的样本是相互独立的,输出的结果之间也不会相互影响。也就是说,这时处理的数据是IID(独立同分布)数据,但序列类的数据却不满足IID特征,所以RNN出场了。

    二、RNN的结构

    看到hello,wor__!你肯定会轻而易举地预测出后两个字符为ld。

    RNN结构如下:

    左侧:x是输入,s相当于隐藏层,o是输出。U、V、W都是权值矩阵。为什么称之为循环那?因为隐藏层的输出不光传给了下一个节点,也传给了它本身。展开如右:t代表时刻,st是时刻t时的记忆,st不仅与t时刻的输入有关,还与上一个时刻的记忆st-1有关,故st=f(Uxt+Wst-1),ot是t时刻的输出,比如是预测下个词的时候,可能是softmax输出的每个候选词的概率。不仅与当前的输入有关,还与之前的记忆有关。图中也可以发现:W、U、V没有变过,所以RNN的权值矩阵是共享的,这样就大大减少了训练的参数。

    这两个图都是一个意思:不仅与当前输入有关系,还有上一时刻的记忆有关系,就和电容这种记忆性元件是一个道理。但是,RNN的记忆是有限的,它不可能把所有的都记住。所以,LSTM又出来解决这个问题了。

     三、LSTM(Long Short Term Memory)网络

    在看电影的时候,情节发展往往要根据之前的细节来推断,因为作者往往藏了伏笔。但RNN网络的记忆细胞随着时间的推移,有些内容它就会忘掉,记不住之前的伏笔。但是LSTM的记忆细胞会把该记住的记住,把该忽略的忽略。

    上图是简单的RNN网络结构。

    这是LSTM结构。它对状态是否参与输入以及状态的更新做了灵活的选择,也就是它可以过滤掉不想再记住的东西,还可以再往里加一些新的东西。

    这种结构的核心思想是引入了一个叫“细胞状态(cell state)”的连接,这个细胞状态用来存放想要记忆的东西,同时在里面加了3个门。

    细胞状态Ct在行走的过程中,总会遇到各种操作。也许乘,也许加。这些都是它走过了一扇又一扇门实现的。

    第一个要过的忘记门:把以前的状态忘记,即决定丢弃什么信息。

    经过这一步,就选出来了那些不想要的东西,ft的值在0—1之间,0表示完全舍弃,1表示完全保留。

    下一个要过输入门:决定加入什么新的状态,即更新细胞状态。

    然后就是细胞状态的更新。 

     

    最后过输出门:把更新后的状态和输入一起输出。

    四、手写数字识别参考小程序:

     1 import tensorflow as tf
     2 from tensorflow.examples.tutorials.mnist import input_data
     3 
     4 #载入数据集
     5 mnist = input_data.read_data_sets("MNIST_data/",one_hot=True)
     6 
     7 # 输入图片是28*28
     8 n_inputs = 28 #输入一行,一行有28个数据,输入神经元,每次输入一行
     9 max_time = 28 #一共28行,每次输入一行,一共需要输入28次
    10 lstm_size = 100 #隐层单元
    11 n_classes = 10 # 10个分类
    12 batch_size = 50 #每批次50个样本
    13 n_batch = mnist.train.num_examples // batch_size #计算一共有多少个批次
    14 
    15 #这里的none表示第一个维度可以是任意的长度
    16 x = tf.placeholder(tf.float32,[None,784])
    17 #正确的标签
    18 y = tf.placeholder(tf.float32,[None,10])
    19 
    20 #初始化权值
    21 weights = tf.Variable(tf.truncated_normal([lstm_size, n_classes], stddev=0.1))
    22 #初始化偏置值
    23 biases = tf.Variable(tf.constant(0.1, shape=[n_classes]))
    24 
    25 
    26 #定义RNN网络
    27 def RNN(X,weights,biases):
    28     # inputs=[batch_size, max_time, n_inputs]
    29     inputs = tf.reshape(X,[-1,max_time,n_inputs])  #X由50*784转化成50*28*28
    30     #定义LSTM基本CELL
    31     lstm_cell = tf.contrib.rnn.core_rnn_cell.BasicLSTMCell(lstm_size)
    32     # final_state[0]是cell state
    33     # final_state[1]是hidden_state
    34     outputs,final_state = tf.nn.dynamic_rnn(lstm_cell,inputs,dtype=tf.float32)
    35     results = tf.nn.softmax(tf.matmul(final_state[1],weights) + biases)
    36     return results
    37     
    38     
    39 #计算RNN的返回结果
    40 prediction= RNN(x, weights, biases)  
    41 #损失函数
    42 cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=prediction,labels=y))
    43 #使用AdamOptimizer进行优化
    44 train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
    45 #结果存放在一个布尔型列表中
    46 correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))#argmax返回一维张量中最大的值所在的位置
    47 #求准确率
    48 accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))#把correct_prediction变为float32类型
    49 #初始化
    50 init = tf.global_variables_initializer()
    51 
    52 with tf.Session() as sess:
    53     sess.run(init)
    54     for epoch in range(6):
    55         for batch in range(n_batch):
    56             batch_xs,batch_ys =  mnist.train.next_batch(batch_size)
    57             sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys})
    58         
    59         acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
    60         print ("Iter " + str(epoch) + ", Testing Accuracy= " + str(acc))
     1 Extracting MNIST_data/train-images-idx3-ubyte.gz
     2 Extracting MNIST_data/train-labels-idx1-ubyte.gz
     3 Extracting MNIST_data/t10k-images-idx3-ubyte.gz
     4 Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
     5 Iter 0, Testing Accuracy= 0.7221
     6 Iter 1, Testing Accuracy= 0.8016
     7 Iter 2, Testing Accuracy= 0.8763
     8 Iter 3, Testing Accuracy= 0.9103
     9 Iter 4, Testing Accuracy= 0.9223
    10 Iter 5, Testing Accuracy= 0.9311

     关于dynamic_rnn定义中的参数:

    • cell:生成好的cell类对象。
    • inputs:输入数据,是一个张量,一般是三维张量:[batch_size,max_time,...],其中batch_size表示一次的批次数量,max_time表示时间序列总数,后面是具体数据。
    • sequence_length:每一个输入的序列长度。

    返回值:一个是结果,一个是cell状态,结果是以[batch_size,max_time,...]形式的张量。

    结论上来说,如果cell为LSTM,那 state是个tuple,分别代表ht和ct,其中ht与outputs中的对应的最后一个时刻的输出相等,假设state形状为[ 2,batch_size, cell.output_size ],outputs形状为 [ batch_size, max_time, cell.output_size ],那么state[ 1, batch_size, : ] == outputs[ batch_size, -1, : ]。参考:https://blog.csdn.net/u010960155/article/details/81707498

    2019-06-18 21:12:44

  • 相关阅读:
    hosts 文件妙用
    asp.net 各种路径
    正则表达式
    int.Parse()、int.TryParse()和Convert.ToInt32()的区别
    总结.NET 中什么时候用 Static
    利用.net的内部机制在asp.net中实现身份验证
    server.transfer 用法
    sql server Datetime格式转换
    如果在代码中使用JS
    js 添加广告
  • 原文地址:https://www.cnblogs.com/direwolf22/p/11008465.html
Copyright © 2011-2022 走看看