  • LSTM state in TensorFlow

       

    Consider the case where state_is_tuple is set. A single cell step is called as:

       

    output, new_state = cell(input, state)

       

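    The state actually has two parts: c (the cell state) and m (the hidden state, labeled hidden or h in the figure below). The hidden state m is also exactly the cell's output.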

       

       

       

       

    # Older TensorFlow signature: concat(concat_dim, values), i.e. join c and m along axis 1
    new_state = (LSTMStateTuple(c, m) if self._state_is_tuple
                 else array_ops.concat(1, [c, m]))
    return m, new_state
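The following is a minimal pure-Python sketch of one LSTM step that makes the (c, m) bookkeeping concrete. Everything in it is hypothetical: ToyLSTMCell, its random weights, and the single-example layout with no batch axis. The gate arithmetic follows the usual LSTM equations, and the forget bias of 1.0 is chosen to resemble BasicLSTMCell. Because there is no batch axis, "concatenate along axis 1" reduces to plain list concatenation.

```python
import math
import random
from typing import List, NamedTuple, Tuple, Union

Vector = List[float]


class LSTMStateTuple(NamedTuple):
    """Stand-in for TF's LSTMStateTuple: (c, h), where h is the m above."""
    c: Vector
    h: Vector


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


class ToyLSTMCell:
    """Hypothetical single-example LSTM cell with random weights (no batch axis)."""

    def __init__(self, input_size: int, num_units: int,
                 state_is_tuple: bool = True, forget_bias: float = 1.0, seed: int = 0):
        rng = random.Random(seed)
        self.num_units = num_units
        self.state_is_tuple = state_is_tuple
        self.forget_bias = forget_bias
        cols = input_size + num_units
        # One weight row per pre-activation: 4 blocks (i, j, f, o) of num_units rows.
        self.w = [[rng.uniform(-0.5, 0.5) for _ in range(cols)]
                  for _ in range(4 * num_units)]

    def __call__(self, x: Vector, state) -> Tuple[Vector, Union[LSTMStateTuple, Vector]]:
        n = self.num_units
        if self.state_is_tuple:
            c, m = state
        else:
            c, m = state[:n], state[n:]  # undo the [c, m] concatenation
        # Pre-activations computed from [x, m] (input concatenated with previous hidden state).
        z = [sum(w * v for w, v in zip(row, x + m)) for row in self.w]
        i, j, f, o = z[:n], z[n:2 * n], z[2 * n:3 * n], z[3 * n:]
        new_c = [ck * sigmoid(fk + self.forget_bias) + sigmoid(ik) * math.tanh(jk)
                 for ck, ik, jk, fk in zip(c, i, j, f)]
        new_m = [math.tanh(ck) * sigmoid(ok) for ck, ok in zip(new_c, o)]
        new_state = LSTMStateTuple(new_c, new_m) if self.state_is_tuple else new_c + new_m
        # The output returned to the caller is the new hidden state m itself.
        return new_m, new_state


x = [1.0, 0.5, -0.2]
tuple_cell = ToyLSTMCell(3, 4, state_is_tuple=True)
out_t, state_t = tuple_cell(x, LSTMStateTuple([0.0] * 4, [0.0] * 4))
flat_cell = ToyLSTMCell(3, 4, state_is_tuple=False)
out_f, state_f = flat_cell(x, [0.0] * 8)

print(out_t == state_t.h)                # True: the output is m
print(state_f == state_t.c + state_t.h)  # True: same numbers, concatenated
```

Both cells use the same seed, so they have identical weights. As a result, the flat state is exactly the tuple's c followed by its h.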

       

       

    def basic_rnn_seq2seq(
            encoder_inputs, decoder_inputs, cell, dtype=dtypes.float32, scope=None):
        with variable_scope.variable_scope(scope or "basic_rnn_seq2seq"):
            _, enc_state = rnn.rnn(cell, encoder_inputs, dtype=dtype)
            return rnn_decoder(decoder_inputs, enc_state, cell)
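Below is a pure-Python sketch of the data flow in basic_rnn_seq2seq. It replaces rnn.rnn with a hypothetical helper, run_rnn, and uses a hypothetical scalar tanh cell, toy_rnn_cell. The encoder's outputs are discarded, and only its final state seeds the decoder. The toy cell's weights are fixed constants, so the sketch says nothing about how TensorFlow creates or shares variables.

```python
import math
from typing import Callable, List, Tuple

State = float
Cell = Callable[[float, State], Tuple[float, State]]


def toy_rnn_cell(x: float, h: State) -> Tuple[float, State]:
    """Hypothetical scalar tanh RNN cell: output and new state are both h'."""
    new_h = math.tanh(0.5 * x + 0.8 * h)
    return new_h, new_h


def run_rnn(cell: Cell, inputs: List[float], state: State) -> Tuple[List[float], State]:
    """Unroll the cell over the inputs, collecting every output."""
    outputs = []
    for x in inputs:
        out, state = cell(x, state)
        outputs.append(out)
    return outputs, state


def basic_seq2seq(encoder_inputs: List[float], decoder_inputs: List[float],
                  cell: Cell, initial_state: State = 0.0):
    # Encoder: keep only the final state; its outputs are dropped (the `_`).
    _, enc_state = run_rnn(cell, encoder_inputs, initial_state)
    # Decoder: starts from the encoder's final state instead of zeros.
    return run_rnn(cell, decoder_inputs, enc_state)


outputs, state = basic_seq2seq([1.0, -1.0, 0.5], [0.0, 0.0], toy_rnn_cell)
print(len(outputs), outputs[-1] == state)  # 2 True
```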

       

       

    def rnn_decoder(decoder_inputs, initial_state, cell, loop_function=None,
                    scope=None):
        with variable_scope.variable_scope(scope or "rnn_decoder"):
            state = initial_state
            outputs = []
            prev = None
            for i, inp in enumerate(decoder_inputs):
                if loop_function is not None and prev is not None:
                    with variable_scope.variable_scope("loop_function", reuse=True):
                        inp = loop_function(prev, i)
                if i > 0:
                    variable_scope.get_variable_scope().reuse_variables()
                output, state = cell(inp, state)
                outputs.append(output)
                if loop_function is not None:
                    prev = output
        return outputs, state
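The following sketch shows what loop_function changes in the loop above. The variable-scope calls are dropped because the toy cell has no variables. accumulator_cell and the doubling loop_function are hypothetical. In a real model, loop_function would typically map the previous output back to an input for the next step. Once loop_function is given, only decoder_inputs[0] is actually consumed; every later input is derived from the previous output.

```python
from typing import Callable, List, Optional, Tuple


def accumulator_cell(inp: float, state: float) -> Tuple[float, float]:
    """Hypothetical toy cell: new state = old state + input; output = new state."""
    new_state = state + inp
    return new_state, new_state


def rnn_decoder_sketch(decoder_inputs: List[float], initial_state: float,
                       cell: Callable[[float, float], Tuple[float, float]],
                       loop_function: Optional[Callable[[float, int], float]] = None):
    state = initial_state
    outputs = []
    prev = None
    for i, inp in enumerate(decoder_inputs):
        if loop_function is not None and prev is not None:
            # From step 1 on, the given input is replaced by a function of the previous output.
            inp = loop_function(prev, i)
        output, state = cell(inp, state)
        outputs.append(output)
        if loop_function is not None:
            prev = output
    return outputs, state


teacher_forced, _ = rnn_decoder_sketch([1.0, 100.0, 100.0], 0.0, accumulator_cell)
fed_back, _ = rnn_decoder_sketch([1.0, 100.0, 100.0], 0.0, accumulator_cell,
                                 loop_function=lambda prev, i: 2 * prev)
print(teacher_forced)  # [1.0, 101.0, 201.0]
print(fed_back)        # [1.0, 3.0, 9.0]
```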

       

       

    Here the decoder takes the encoder's final state as its initial state.

       

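    It then returns the final state of the decoding pass together with the list of all outputs (that is, the sequence of hidden states).

    Note that outputs[-1] holds exactly the same values as m inside the returned state.

    With dynamic_rnn, however, the trailing outputs may be zero-padded for time steps beyond a sequence's length, so outputs[-1] can be all zeros while the state still holds the last real m, as in the example below.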
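The following pure-Python sketch shows the padding behaviour just described: past each sequence's length, the output is zero and the state is carried over unchanged. The scalar cell toy_cell and the helper dynamic_rnn_sketch are hypothetical stand-ins, not TensorFlow code. The input shape (one sequence, four steps, the last one padding) is chosen to match the pattern of the printed example below.

```python
import math
from typing import List, Tuple


def toy_cell(x: float, h: float) -> Tuple[float, float]:
    """Hypothetical scalar tanh cell: the output and the new state are both h."""
    new_h = math.tanh(0.5 * x + 0.8 * h)
    return new_h, new_h


def dynamic_rnn_sketch(inputs: List[List[float]], seq_len: List[int]):
    """Batch-major inputs [batch][max_time]; zero outputs and copy state past seq_len."""
    all_outputs, final_states = [], []
    for seq, length in zip(inputs, seq_len):
        h = 0.0
        outs = []
        for t, x in enumerate(seq):
            if t < length:
                out, h = toy_cell(x, h)
            else:
                # Past this sequence's end: emit a zero output and leave h unchanged.
                out = 0.0
            outs.append(out)
        all_outputs.append(outs)
        final_states.append(h)
    return all_outputs, final_states


outputs, states = dynamic_rnn_sketch([[1.0, 2.0, 3.0, 0.0]], seq_len=[3])
print(outputs[0][-1])              # 0.0
print(outputs[0][2] == states[0])  # True
```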

       

    encode_feature, state = melt.rnn.encode(
        cell,
        inputs,
        seq_length,
        encode_method=0,
        output_method=3)

       

    encode_feature.eval()

    array([[[ 4.27834410e-03,  1.45841937e-03,  1.25767402e-02,  5.00775501e-03],
            [ 6.24437723e-03,  2.60074623e-03,  2.32168660e-02,  9.47457738e-03],
            [ 7.59789022e-03, -5.34060055e-05,  1.64511874e-02, -5.71310846e-03],
            [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00]]],
          dtype=float32)

       

       

    state[1].eval()

    array([[ 7.59789022e-03, -5.34060055e-05,  1.64511874e-02, -5.71310846e-03]],
          dtype=float32)
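When outputs are zero-padded, the hidden state can be recovered from the outputs by indexing each sequence at its own last valid step rather than at [-1]. The sketch below reuses the values printed above. seq_length = [3] is an assumption read off the all-zero fourth row; the post does not state it. last_relevant is a hypothetical helper name.

```python
from typing import List

# Values copied from the encode_feature.eval() output above: [batch=1][time=4][units=4].
encode_feature = [[
    [4.27834410e-03, 1.45841937e-03, 1.25767402e-02, 5.00775501e-03],
    [6.24437723e-03, 2.60074623e-03, 2.32168660e-02, 9.47457738e-03],
    [7.59789022e-03, -5.34060055e-05, 1.64511874e-02, -5.71310846e-03],
    [0.0, 0.0, 0.0, 0.0],
]]
# Values from state[1].eval() above: the h (m) part of the final LSTMStateTuple.
state_h = [[7.59789022e-03, -5.34060055e-05, 1.64511874e-02, -5.71310846e-03]]
# Assumption: the only sequence has length 3 (the 4th step is padding).
seq_length = [3]


def last_relevant(outputs: List[List[List[float]]], lengths: List[int]) -> List[List[float]]:
    """Pick outputs[b][lengths[b] - 1] for every batch element b."""
    return [seq[n - 1] for seq, n in zip(outputs, lengths)]


print(last_relevant(encode_feature, seq_length) == state_h)  # True
print(encode_feature[0][-1] == state_h[0])                   # False: last step is padding
```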

       

       

       

  • Original post: https://www.cnblogs.com/rocketfan/p/6257137.html