  • The LSTM state in TensorFlow

       

    Consider the state_is_tuple option.

       

    output, new_state = cell(input, state)

       

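    The state actually consists of two parts: the cell state c, and m (labeled hidden, or h, in the original post's figure). m (the hidden state) is in fact also the cell's output.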
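
    To make the relationship between the output and the two parts of the state concrete, here is a minimal sketch of a single LSTM step in plain Python. It is not TensorFlow's implementation: toy_lstm_cell, its scalar weights w and u, and the local LSTMStateTuple namedtuple are assumptions made for illustration only.

```python
import math
from collections import namedtuple

# Local stand-in for TensorFlow's LSTMStateTuple (an assumption for this sketch).
LSTMStateTuple = namedtuple("LSTMStateTuple", ["c", "h"])


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def toy_lstm_cell(x, state, w=0.5, u=0.3, forget_bias=1.0):
    """One step of a scalar LSTM; w and u are hypothetical shared gate weights."""
    c_prev, m_prev = state
    pre = w * x + u * m_prev          # shared pre-activation, for brevity
    i = sigmoid(pre)                  # input gate
    f = sigmoid(pre + forget_bias)    # forget gate
    o = sigmoid(pre)                  # output gate
    j = math.tanh(pre)                # candidate value
    c = f * c_prev + i * j            # new cell state
    m = o * math.tanh(c)              # new hidden state
    # The output is m itself; the new state carries both c and m.
    return m, LSTMStateTuple(c, m)


output, new_state = toy_lstm_cell(1.0, LSTMStateTuple(0.0, 0.0))
print(output == new_state.h)
```

    The point of the sketch is only the return value: the first element is m, and the same m appears again as the second field of the state.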

       

       

       

       

    new_state = (LSTMStateTuple(c, m) if self._state_is_tuple
                 else array_ops.concat(1, [c, m]))
    return m, new_state
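
    The two branches differ only in how c and m are packaged. A rough sketch of both layouts for a single example, using Python lists in place of tensors; pack_state and unpack_state are hypothetical helper names, not TensorFlow functions:

```python
from collections import namedtuple

# Local stand-in for TensorFlow's LSTMStateTuple (an assumption for this sketch).
LSTMStateTuple = namedtuple("LSTMStateTuple", ["c", "h"])


def pack_state(c, m, state_is_tuple):
    """Package c and m (lists of floats) the way the branch above does."""
    if state_is_tuple:
        return LSTMStateTuple(c, m)
    # List concatenation stands in for concatenating along the feature axis.
    return c + m


def unpack_state(state, num_units, state_is_tuple):
    """Recover c and m from either layout."""
    if state_is_tuple:
        return state.c, state.h
    # With a concatenated state, the caller has to know where c ends.
    return state[:num_units], state[num_units:]


c, m = [0.1, 0.2], [0.3, 0.4]
for flag in (True, False):
    packed = pack_state(c, m, flag)
    print(flag, unpack_state(packed, num_units=2, state_is_tuple=flag))
```

    With the tuple layout, state[1] is m directly, which is why the later example can read it with state[1].eval().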

       

       

    def basic_rnn_seq2seq(
        encoder_inputs, decoder_inputs, cell, dtype=dtypes.float32, scope=None):
      with variable_scope.variable_scope(scope or "basic_rnn_seq2seq"):
        _, enc_state = rnn.rnn(cell, encoder_inputs, dtype=dtype)
        return rnn_decoder(decoder_inputs, enc_state, cell)

       

       

    def rnn_decoder(decoder_inputs, initial_state, cell, loop_function=None,
                    scope=None):
      with variable_scope.variable_scope(scope or "rnn_decoder"):
        state = initial_state
        outputs = []
        prev = None
        for i, inp in enumerate(decoder_inputs):
          if loop_function is not None and prev is not None:
            with variable_scope.variable_scope("loop_function", reuse=True):
              inp = loop_function(prev, i)
          if i > 0:
            variable_scope.get_variable_scope().reuse_variables()
          output, state = cell(inp, state)
          outputs.append(output)
          if loop_function is not None:
            prev = output
      return outputs, state
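
    The control flow of rnn_decoder can be tried out without TensorFlow. The sketch below keeps the same loop but drops the variable scopes; toy_cell and decode are hypothetical names, and the cell's arithmetic is arbitrary, chosen only so that the output equals the m part of the state.

```python
def toy_cell(inp, state):
    """Hypothetical cell: state is a (c, m) pair and the output is m."""
    c, m = state
    c = 0.5 * c + inp
    m = 0.1 * c
    return m, (c, m)


def decode(decoder_inputs, initial_state, cell, loop_function=None):
    """Same loop as rnn_decoder, minus variable scoping."""
    state = initial_state
    outputs = []
    prev = None
    for i, inp in enumerate(decoder_inputs):
        # From the second step on, loop_function replaces the given input
        # with something computed from the previous output.
        if loop_function is not None and prev is not None:
            inp = loop_function(prev, i)
        output, state = cell(inp, state)
        outputs.append(output)
        if loop_function is not None:
            prev = output
    return outputs, state


enc_state = (0.0, 0.0)  # stands in for the encoder's final state
outputs, state = decode([1.0, 2.0, 3.0], enc_state, toy_cell)
print(outputs[-1] == state[1])

# With a loop_function, only the first decoder input is actually used.
fed_back, _ = decode([1.0, 2.0, 3.0], enc_state, toy_cell,
                     loop_function=lambda prev, i: prev)
print(fed_back)
```

    One output is collected per step, while only the last state is returned, so the caller sees every m but only the final c.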

       

       

    Here the decoder takes the encoder's final state as its input, namely as initial_state.

       

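    The result it returns is the decoder's final state plus the list of all outputs (that is, the hidden state from every step).

    Note that outputs[-1] has exactly the same values as the m in state.

    With dynamic rnn, however, the trailing outputs may be zero-padded, in which case outputs[-1] is all zeros and m matches the last non-padded output instead.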
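
    The zero-padding case can be sketched as follows, assuming the usual dynamic rnn behavior of emitting zeros and copying the state through once a sequence has ended. dynamic_unroll and toy_cell are hypothetical names, and scalars stand in for tensors.

```python
def toy_cell(inp, state):
    """Hypothetical cell: state is a (c, m) pair and the output is m."""
    c, m = state
    c = 0.5 * c + inp
    m = 0.1 * c
    return m, (c, m)


def dynamic_unroll(inputs, seq_length, initial_state, cell):
    """Unroll over padded inputs, updating the state only for real steps."""
    state = initial_state
    outputs = []
    for t, inp in enumerate(inputs):
        if t < seq_length:
            output, state = cell(inp, state)
        else:
            # Padded step: emit zero and leave the state untouched.
            output = 0.0
        outputs.append(output)
    return outputs, state


# Three real steps followed by one padded step.
outputs, state = dynamic_unroll([1.0, 2.0, 3.0, 0.0], 3, (0.0, 0.0), toy_cell)
print(outputs)
print(state[1] == outputs[2], outputs[-1] == 0.0)
```

    Under these assumptions the last element of outputs is zero, and the m in the state equals the output at index seq_length - 1 rather than outputs[-1].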

       

    encode_feature, state = melt.rnn.encode(
        cell,
        inputs,
        seq_length,
        encode_method=0,
        output_method=3)

       

    encode_feature.eval()

    array([[[ 4.27834410e-03,  1.45841937e-03,  1.25767402e-02,  5.00775501e-03],
            [ 6.24437723e-03,  2.60074623e-03,  2.32168660e-02,  9.47457738e-03],
            [ 7.59789022e-03, -5.34060055e-05,  1.64511874e-02, -5.71310846e-03],
            [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00]]], dtype=float32)

       

       

    state[1].eval()

    array([[ 7.59789022e-03, -5.34060055e-05,  1.64511874e-02, -5.71310846e-03]], dtype=float32)

       

       

       

  • Original article: https://www.cnblogs.com/rocketfan/p/6257137.html