zoukankan      html  css  js  c++  java
  • 从rnn到lstm,再到seq2seq(一)

    rnn的的公式很简单:

    对于每个时刻,输入上一个时刻的隐层s和这个时刻的文本x,然后输出这个时刻的隐层s。对于输出的隐层s 做个ws+b就是这个时刻的输出y。

    tf.scan(fn, elems, initializer) # scan operation
    
    def fn(st_1, xt): # recurrent function
    
        st = f(st_1, xt)
        return st

    rnn的实现:

    def step(hprev, x):
        # initializer
        xav_init = tf.contrib.layers.xavier_initializer
        # params
        W = tf.get_variable('W', shape=[state_size, state_size], initializer=xav_init())
        U = tf.get_variable('U', shape=[state_size, state_size], initializer=xav_init())
        b = tf.get_variable('b', shape=[state_size], initializer=tf.constant_initializer(0.))
        # current hidden state
        h = tf.tanh(tf.matmul(hprev, W) + tf.matmul(x,U) + b)
        return h
    states = tf.scan(step, 
                tf.transpose(rnn_inputs, [1,0,2]),
                initializer=init_state) 

    lstm只是网络结构上个对rnn进行改进,它同时增加一个单元叫做state状态,每个lstm有个hidden和一个state。

    下面图中h就是隐层,下面图中的c就是状态。首先根据这个时刻的输入x和上个时刻的隐层算出三个门,f(forget),i(input),o(ouput)

    激活函数是sigmoid函数,输出0或者1。算出来的f门是来控制上个状态多少被忘记。算出来的i门来控制这个时刻状态的多少被输入。

    本时刻的状态由这个时刻的输入x和上个时刻的隐层算出然后用tan函数激活(对应第四行公式)。

    本时刻隐层的输出h是由本时刻的状态用tan来激活,然后乘以输出门

     看看lstm的实现:

                def step(prev, x):
                    # gather previous internal state and output state
                    st_1, ct_1 = tf.unpack(prev)
                    ####
                    # GATES
                    #
                    #  input gate
                    i = tf.sigmoid(tf.matmul(x,U[0]) + tf.matmul(st_1,W[0]))
                    #  forget gate
                    f = tf.sigmoid(tf.matmul(x,U[1]) + tf.matmul(st_1,W[1]))
                    #  output gate
                    o = tf.sigmoid(tf.matmul(x,U[2]) + tf.matmul(st_1,W[2]))
                    #  gate weights
                    g = tf.tanh(tf.matmul(x,U[3]) + tf.matmul(st_1,W[3]))
                    ###
                    # new internal cell state
                    ct = ct_1*f + g*i
                    # output state
                    st = tf.tanh(ct)*o
                    return tf.pack([st, ct])
                ###
                # here comes the scan operation; wake up!
                #   tf.scan(fn, elems, initializer)
                states = tf.scan(step, 
                        tf.transpose(rnn_inputs, [1,0,2]),
                        initializer=init_state)

    在来看下gru

    gru里面没有state这个东西,它有两个门,一个是z,遗忘门,一个是r,就是reset门

    跟lstm。算出遗忘门,来控制上个时刻的多少隐层被遗忘,另一半(1-z)就是本时刻多少隐层被输入。

    本时刻多少隐层,跟lstm也很相似,只是在上个时刻的h上加了个reset门,就是:根据上个时刻的h加上reset门,和本时刻的输入x,通过tan来激活

    看看gru的实现:

      def step(st_1, x):
                    ####
                    # GATES
                    #
                    #  update gate
                    z = tf.sigmoid(tf.matmul(x,U[0]) + tf.matmul(st_1,W[0]))
                    #  reset gate
                    r = tf.sigmoid(tf.matmul(x,U[1]) + tf.matmul(st_1,W[1]))
                    #  intermediate
                    h = tf.tanh(tf.matmul(x,U[2]) + tf.matmul( (r*st_1),W[2]))
                    ###
                    # new state
                    st = (1-z)*h + (z*st_1)
                    return st
                ###
                # here comes the scan operation; wake up!
                #   tf.scan(fn, elems, initializer)
                states = tf.scan(step, 
                        tf.transpose(rnn_inputs, [1,0,2]),
                        initializer=init_state)

    参考文章:

    http://colah.github.io/posts/2015-08-Understanding-LSTMs/

    http://suriyadeepan.github.io/2017-02-13-unfolding-rnn-2/

    https://github.com/suriyadeepan/rnn-from-scratch

    http://karpathy.github.io/2015/05/21/rnn-effectiveness/

  • 相关阅读:
    git代码回退
    7 用两个栈实现队列
    《Java并发编程实战》学习笔记
    226. Invert Binary Tree
    Interface与abstract类的区别
    Override和Overload的区别
    Java面向对象的三个特征与含义
    String、StringBuffer与StringBuilder的区别
    Hashcode的作用
    Object有哪些公用方法
  • 原文地址:https://www.cnblogs.com/dmesg/p/6882664.html
Copyright © 2011-2022 走看看