  • Using RNNCell

    Recap

    [Figure: RNN diagram]

    input dim, hidden dim

    from tensorflow.keras import layers
    
    # the cell computes $x w_{xh} + h w_{hh}$; 3 is the hidden dim
    cell = layers.SimpleRNNCell(3)
    cell.build(input_shape=(None, 4))
    
    cell.trainable_variables
    
    [<tf.Variable 'kernel:0' shape=(4, 3) dtype=float32, numpy=
     array([[-0.5311725 ,  0.7757399 , -0.19041312],
            [ 0.90420175, -0.14276218,  0.1546886 ],
            [ 0.81770146, -0.46731013, -0.05373603],
            [ 0.49086082,  0.10275221,  0.10146773]], dtype=float32)>,
     <tf.Variable 'recurrent_kernel:0' shape=(3, 3) dtype=float32, numpy=
     array([[ 0.7557267 , -0.58395827,  0.2964283 ],
            [-0.64145935, -0.56886935,  0.5147014 ],
            [-0.13193521, -0.5791204 , -0.8044953 ]], dtype=float32)>,
     <tf.Variable 'bias:0' shape=(3,) dtype=float32, numpy=array([0., 0., 0.], dtype=float32)>]
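
    The three variables above are all a SimpleRNNCell needs for one update. The following is a minimal, framework-free sketch of that update for a single sample, assuming the default tanh activation. The weights are copied from the printout above; the one-hot input `x_t` and the helper names `vec_mat` and `simple_rnn_step` are illustrative assumptions, not part of the original.

```python
import math

# Weights copied from the trainable_variables printout above.
kernel = [  # W_xh, shape (input dim 4, hidden dim 3)
    [-0.5311725, 0.7757399, -0.19041312],
    [0.90420175, -0.14276218, 0.1546886],
    [0.81770146, -0.46731013, -0.05373603],
    [0.49086082, 0.10275221, 0.10146773],
]
recurrent_kernel = [  # W_hh, shape (hidden dim 3, hidden dim 3)
    [0.7557267, -0.58395827, 0.2964283],
    [-0.64145935, -0.56886935, 0.5147014],
    [-0.13193521, -0.5791204, -0.8044953],
]
bias = [0.0, 0.0, 0.0]  # shape (hidden dim 3,)


def vec_mat(v, m):
    """Multiply a row vector v (length n) by a matrix m (n x k)."""
    return [sum(v[i] * m[i][j] for i in range(len(v))) for j in range(len(m[0]))]


def simple_rnn_step(x_t, h_prev):
    """One update for one sample: tanh(x_t @ W_xh + h_prev @ W_hh + b)."""
    xw = vec_mat(x_t, kernel)
    hw = vec_mat(h_prev, recurrent_kernel)
    return [math.tanh(a + c + b) for a, c, b in zip(xw, hw, bias)]


# Parameter count: 4*3 (kernel) + 3*3 (recurrent kernel) + 3 (bias)
n_params = (len(kernel) * len(kernel[0])
            + len(recurrent_kernel) * len(recurrent_kernel[0])
            + len(bias))

x_t = [1.0, 0.0, 0.0, 0.0]  # hypothetical one-hot input, for illustration only
h_0 = [0.0, 0.0, 0.0]       # zero initial state
h_1 = simple_rnn_step(x_t, h_0)  # with h_0 = 0 this is tanh(first row of kernel)
h_2 = simple_rnn_step(x_t, h_1)  # the recurrent term now contributes as well
print(n_params, h_1, h_2)
```

    With a zero initial state only the input term matters, so `h_1` is just the tanh of the first kernel row; from the second step on, the recurrent kernel mixes the previous state in.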
    

    SimpleRNNCell

    • out, h_1 = call(x, h_0)
      • x: [b, seq len, word vec] for the whole sequence; the cell itself takes one time step, x_t: [b, word vec]

      • h_0/h_1: [b, h dim]

      • out: [b, h dim]

    Single layer RNN Cell

    import tensorflow as tf
    
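    x = tf.random.normal([4, 80, 100])  # [b, seq len, word vec]
    ht0 = x[:, 0, :]  # input at time step 0: [b, word vec]
    
    cell = tf.keras.layers.SimpleRNNCell(64)  # hidden dim 64
    
    # The state is passed as a list; the initial state h_0 is all zeros.
    out, ht1 = cell(ht0, [tf.zeros([4, 64])])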
    
    out.shape, ht1[0].shape
    
    (TensorShape([4, 64]), TensorShape([4, 64]))
    
    id(out), id(ht1[0])  # same id
    
    (4877125168, 4877125168)
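
    The snippet above feeds only time step 0 into the cell. To process a whole sequence, the same cell is applied at every time step while the state is carried forward. Below is a framework-free sketch of that loop; the small sizes, the random weights, and the helper names `rand_matrix`, `matmul`, and `cell` are assumptions chosen for illustration and stand in for the 4 x 80 x 100 input and the 64-unit cell.

```python
import math
import random

random.seed(0)
BATCH, SEQ_LEN, FEAT, UNITS = 2, 5, 3, 4  # small stand-ins for 4, 80, 100, 64


def rand_matrix(rows, cols):
    """Random weights, for illustration only."""
    return [[random.uniform(-0.5, 0.5) for _ in range(cols)] for _ in range(rows)]


def matmul(a, b):
    """Plain matrix product of a (n x k) and b (k x m)."""
    return [[sum(row[k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for row in a]


def cell(x_t, h, w_xh, w_hh, bias):
    """x_t: [b, feat], h: [b, units] -> new h: [b, units]."""
    xw, hw = matmul(x_t, w_xh), matmul(h, w_hh)
    return [[math.tanh(xw[i][j] + hw[i][j] + bias[j]) for j in range(len(bias))]
            for i in range(len(x_t))]


# x: [b, seq len, feat]
x = [[[random.gauss(0, 1) for _ in range(FEAT)] for _ in range(SEQ_LEN)]
     for _ in range(BATCH)]
w_xh, w_hh, bias = rand_matrix(FEAT, UNITS), rand_matrix(UNITS, UNITS), [0.0] * UNITS

h = [[0.0] * UNITS for _ in range(BATCH)]  # zero initial state, [b, units]
for t in range(SEQ_LEN):
    x_t = [sample[t] for sample in x]      # the slice x[:, t, :]
    h = cell(x_t, h, w_xh, w_hh, bias)     # same weights reused at every step

print(len(h), len(h[0]))  # final state shape: [b, units]
```

    In TensorFlow the equivalent loop would iterate over `tf.unstack(x, axis=1)` and call `cell(xt, state)` on each slice, reassigning `state` from the returned value.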
    

    Multi-Layer RNN

    [Figure: multi-layer RNN diagram]

    x = tf.random.normal([4, 80, 100])
    ht0 = x[:, 0, :]
    
    cell = tf.keras.layers.SimpleRNNCell(64)
    cell2 = tf.keras.layers.SimpleRNNCell(64)
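    state0 = [tf.zeros([4, 64])]  # initial state of layer 1
    state1 = [tf.zeros([4, 64])]  # initial state of layer 2
    
    out0, state0 = cell(ht0, state0)
    # Layer 2 takes layer 1's output as its input and uses its own state.
    out2, state2 = cell2(out0, state1)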
    
    out2.shape, state2[0].shape
    
    (TensorShape([4, 64]), TensorShape([4, 64]))
    

    RNN Layer

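    # Inside a model's __init__ (units is the hidden dim):
    self.rnn = keras.Sequential([
        # return_sequences=True so the next layer receives every time step
        layers.SimpleRNN(units, dropout=0.5, return_sequences=True, unroll=True),
        layers.SimpleRNN(units, dropout=0.5, unroll=True)
    ])
    # Inside call():
    x = self.rnn(x)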
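
    When two recurrent layers are stacked, the first one has to return its output at every time step so that the second has a full sequence to consume; the second can then return only its last output. The sketch below illustrates that shape logic for a single sample without any framework. The helper `run_layer`, the random weights, and the small sizes are assumptions for illustration; bias and dropout are left out for brevity.

```python
import math
import random

random.seed(1)
SEQ_LEN, FEAT, UNITS = 5, 3, 4  # small illustrative sizes


def rand_matrix(rows, cols):
    """Random weights, for illustration only."""
    return [[random.uniform(-0.5, 0.5) for _ in range(cols)] for _ in range(rows)]


def run_layer(seq, w_xh, w_hh, return_sequences):
    """Run a tanh RNN over seq (a list of per-step vectors for ONE sample).

    Returns every step's output if return_sequences is True,
    otherwise only the last step's output. Bias is omitted (treated as zero).
    """
    h = [0.0] * len(w_hh)  # zero initial state
    outputs = []
    for x_t in seq:
        h = [math.tanh(sum(x_t[i] * w_xh[i][j] for i in range(len(x_t)))
                       + sum(h[k] * w_hh[k][j] for k in range(len(h))))
             for j in range(len(w_hh))]
        outputs.append(h)
    return outputs if return_sequences else outputs[-1]


seq = [[random.gauss(0, 1) for _ in range(FEAT)] for _ in range(SEQ_LEN)]
layer1 = (rand_matrix(FEAT, UNITS), rand_matrix(UNITS, UNITS))   # input dim FEAT
layer2 = (rand_matrix(UNITS, UNITS), rand_matrix(UNITS, UNITS))  # input dim UNITS

hidden_seq = run_layer(seq, *layer1, return_sequences=True)     # [seq len, units]
final = run_layer(hidden_seq, *layer2, return_sequences=False)  # [units]
print(len(hidden_seq), len(hidden_seq[0]), len(final))
```

    Note that the second layer's input dimension equals the first layer's hidden dimension, which matches the cell-by-cell version earlier, where `cell2` receives `cell`'s output.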
    
  • Original article: https://www.cnblogs.com/nickchen121/p/10963544.html