zoukankan      html  css  js  c++  java
  • 循环神经网络入门的一个总结

    1、常用神经网络结构中有个叫RNN的,即循环神经网络。

    假设有n个cell,从第一个cell开始说起。

    state 0+time0 进入cell ,cell处理,处理后的结果,可以分成两个相同的,一个用来输出该层的输出,另一个送给下一个cell,当然,分成两个相同的之后,想怎么变就怎么变。

    对第二个cell来说,第一次的输出和当前时间,是他的输出,就这样,上一个输出始终作为下一个的输入。

    深层循环是什么意思呢,就是我们不是两个输出吗,一个用来给下一个cell,那另一个是不是就可以在这个cell的基础上,往深处传播,当然,这个深处在自然语言处理中,一般用于处理单词的词向量。

    2、在循环神经网络中,有个结构叫长短期记忆网络(LSTM)

    为什么会有个长短期记忆网络呢,因为循环网络实在太长了,但是前面的很多东西我们根本用不上,所以呢 ,对于某些cell来说,我就要忘记一些东西。所以就有了长短期记忆网络。长短期记忆网络的总体结构和上面循环神经网络很像,不过呢,他会多了一个遗忘门,对于以往输入的数据,他会根据当前数据选择性的删除。

    #定义LSTM结构
    lstm = rnn_cell.BasicLSTMCell(lstm_hidden_size)
    #状态初始化,初始化为一个全0数组
    state = lstm.zero_state(batch_size,tf.float32)
    #loss初始化
    loss = 0.0
    #规定一个长度,在这个长度内进行运算,太长了对计算机负担太大并且没有必要,这个for语句,实际上就是循环运算, 每次get——variable——socpe,再将lstm算得的数据输出,最后加上。
    for i in range(num_steps):
        if i > 0 : tf.get_variable_scope().reuse_variables()
        lstm_output, state = lstm(current_input, state)
        final_output = fully_connected(lstm_output)
        
        loss += calc_loss(final_output, expected_output)

    3、深层神经网络的代码基础

    lstm = rnn_cell.BasicLSTMCell(lstm_size)
    #实际上,我们发现,深层循环网络相比普通的,就是在原来的lstm基础上,以lstm*层数为参数,输入到新的函数:MultiRNNCell里面即可。
    stacked_lstm = rnn_cell.MultiRNNCell([lstm]*number_of_layers)
    
    state = stacked_lstm.zero_state(batch_size, tf.float32)
    
    for i in range(len(num_steps)):
        if i > 0: tf.get_variable_scope().reuse_variables()
        stacked_lstm_output, state = stacked_lstm(current_input,state)
        final_output = fully_connected(stacked_lstm_output)
        loss += calc_loss(final_output,expected_output)

    4、循环神经网络的dropout

    lstm = rnn_cell.BasicLSTMCell(lstm_size)
    #dropout实际上是将原来的lstm输入函数中,生成一个dropout的latm
    dropout_lstm = tf.nn.rnn_cell.DropoutWrapper(lstm,output_keep_prob=0.5)
    #实际上,我们发现,深层循环网络相比普通的,就是在原来的lstm基础上,以dropout_lstm*层数为参数,输入到新的函数:MultiRNNCell里面即可。
    stacked_lstm = rnn_cell.MultiRNNCell([dropout_lstm]*number_of_layers)
    
    state = stacked_lstm.zero_state(batch_size, tf.float32)
    
    for i in range(len(num_steps)):
        if i > 0: tf.get_variable_scope().reuse_variables()
        stacked_lstm_output, state = stacked_lstm(current_input,state)
        final_output = fully_connected(stacked_lstm_output)
        loss += calc_loss(final_output,expected_output)
  • 相关阅读:
    Spring boot unable to determine jdbc url from datasouce
    Unable to create initial connections of pool. spring boot mysql
    spring boot MySQL Public Key Retrieval is not allowed
    spring boot no identifier specified for entity
    Establishing SSL connection without server's identity verification is not recommended
    eclipse unable to start within 45 seconds
    Oracle 数据库,远程访问 ora-12541:TNS:无监听程序
    macOS 下安装tomcat
    在macOS 上添加 JAVA_HOME 环境变量
    Maven2: Missing artifact but jars are in place
  • 原文地址:https://www.cnblogs.com/baochen/p/8997731.html
Copyright © 2011-2022 走看看