  • Keras Mask Experiment Summary (original post)

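    Conclusions:

    1. A mask creates a boolean mask matrix that is passed down layer by layer together with each layer's output tensor. If a later layer cannot accept a mask, Keras raises an error.
    2. Embedding with mask_zero=True works as a mask source.
    3. A Masking layer can sit in front of Concatenate and Dense layers. Although the output tensors of those layers look unaffected by the mask, the mask keeps propagating downstream and affects the layers that follow.
    4. Masks mainly act on RNN layers, which skip the masked timesteps. In the output tensor, a masked timestep shows up either as 0 (when no earlier timestep has been computed yet) or as a copy of the previous timestep's output.
    5. If any input to Concatenate masks a given timestep, that timestep is masked in the whole concatenated output.
    6. Do not apply Masking a second time, because it recomputes the mask from scratch. This matters especially after an Embedding layer: the embedded vector of a masked (index 0) timestep is not zero, so it no longer matches mask_value.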
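
    The mask rules above can be emulated in NumPy to see what each layer hands to the next. This is a sketch rather than Keras code: the helpers `masking_mask`, `embedding_mask` and `concat_mask` are hypothetical, and they assume Keras 2.x semantics, where `Masking` keeps a timestep if any feature differs from `mask_value`, `Embedding(mask_zero=True)` keeps a timestep if its index is non-zero, and `Concatenate` keeps a timestep only if every input keeps it. The data is the fake batch used in the test code later in this post, and the embedding table uses made-up random weights.

```python
import numpy as np

def masking_mask(x, mask_value=0.0):
    # Masking: a timestep is kept (True) if ANY feature differs from mask_value.
    return np.any(x != mask_value, axis=-1)

def embedding_mask(ids):
    # Embedding(mask_zero=True): a timestep is kept if its index is non-zero.
    return ids != 0

def concat_mask(*masks):
    # Concatenate: a timestep survives only if every input keeps it (conclusion 5).
    return np.logical_and.reduce(masks)

# Fake batch from the test code: num is (1, 5, 3), c1/c2 are (1, 5, 1) index inputs.
num = np.array([[[0, 0, 0], [1, 2, 3], [0, 0, 0], [1, 2, 3], [0, 0, 0]]])
c1 = np.array([[[0], [1], [0], [1], [0]]])
c2 = np.array([[[0], [1], [0], [1], [0]]])

m_num = masking_mask(num)                  # shape (1, 5)
m_c1 = embedding_mask(c1.reshape(1, -1))   # Reshape((-1,)) then Embedding
m_c2 = embedding_mask(c2.reshape(1, -1))
merged = concat_mask(m_num, m_c1, m_c2)
print(merged)                              # keeps timesteps 1 and 3 only

# Conclusion 6: index 0 is looked up in a random table (hypothetical weights),
# so the embedded "padding" vector is not all zeros.
rng = np.random.default_rng(0)
table = rng.uniform(-0.05, 0.05, size=(4, 1))  # stand-in for Embedding(4, 1) weights
embedded = table[c2.reshape(1, -1)]            # shape (1, 5, 1)
print(masking_mask(embedded))                  # a second Masking(0) keeps every step
```

    The last check illustrates conclusion 6: after an Embedding layer, a padded timestep is a small non-zero vector, so re-applying Masking(mask_value=0) keeps every timestep instead of the intended ones.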

    Experiments:

    Model code:

    # Keras 2.x (standalone) API, as used in the original experiment
    import numpy as np
    from keras.layers import (Input, Reshape, Embedding, Masking, concatenate,
                              GRU, Dense, Activation)
    from keras.models import Model
    from keras.optimizers import Adam


    def rnn_model(x_train, y_train):
        # Inputs: x_train = [numeric, version ids, missing-flag ids],
        # each of shape (samples, timesteps, features)
        num = Input(shape=(x_train[0].shape[1], x_train[0].shape[2]))
        version = Input(shape=(x_train[1].shape[1], x_train[1].shape[2]))
        missing = Input(shape=(x_train[2].shape[1], x_train[2].shape[2]))
        inputs = [num, version, missing]

        # Embedding for categorical variables: flatten (timesteps, 1) to
        # (timesteps,) and let index 0 produce the mask (mask_zero=True)
        reshape_version = Reshape(target_shape=(-1,))(version)
        embedding_version = Embedding(
            180, 2,
            input_length=x_train[1].shape[1] * x_train[1].shape[2],
            mask_zero=True, name='M_version')(reshape_version)
        reshape_missing = Reshape(target_shape=(-1,))(missing)
        # input_length must come from the "missing" input (x_train[2])
        embedding_missing = Embedding(
            4, 1,
            input_length=x_train[2].shape[1] * x_train[2].shape[2],
            mask_zero=True, name='M_missing')(reshape_missing)

        # Numeric input: timesteps whose features are all 0 are masked
        masked_num = Masking(mask_value=0, name='M_num')(num)

        # Concatenate layer: a timestep masked in any input is masked in the output
        merge_ft = concatenate([masked_num, embedding_version, embedding_missing],
                               axis=-1, name='concate')

        # GRU with variable-length sequences.
        # Do not add another Masking layer here: it would overwrite the mask
        # propagated from the layers above. As long as part of a timestep is
        # masked, the whole timestep is masked and is not computed.

        # Optional check for conclusion 3: a Dense layer passes the mask through.
        # merge_ft = Dense(3, name='test')(merge_ft)

        gru_1 = GRU(3, return_sequences=True, name='gru_1')(merge_ft)
        gru_2 = GRU(3, return_sequences=True, name='gru_2')(gru_1)
        gru_3 = GRU(3, name='gru_3')(gru_2)

        dense_ft = Dense(2, name='dense_ft')(gru_3)
        outputs = Activation('softmax', name='outputs')(dense_ft)

        model = Model(inputs=inputs, outputs=outputs)
        adam = Adam(lr=0.0001, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=1e-6)
        model.compile(loss='categorical_crossentropy', optimizer=adam)
        return model
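
    The Reshape/Embedding/concatenate wiring only lines up because each categorical input has exactly one feature per timestep: `Reshape((-1,))` turns `(timesteps, features)` into `timesteps * features` indices, while `concatenate(..., axis=-1)` needs all inputs to share the same time axis. The NumPy sketch below traces the shapes for the test data; `embed` is a hypothetical stand-in for `Embedding` that simply looks rows up in a random table.

```python
import numpy as np

rng = np.random.default_rng(0)

def embed(ids, vocab_size, dim):
    # Hypothetical stand-in for Embedding(vocab_size, dim): one random row per index.
    table = rng.uniform(-0.05, 0.05, size=(vocab_size, dim))
    return table[ids]

# Test data: (batch, timesteps, features)
num = np.array([[[0, 0, 0], [1, 2, 3], [0, 0, 0], [1, 2, 3], [0, 0, 0]]], dtype=float)
version = np.array([[[0], [1], [0], [1], [0]]])
missing = version.copy()

emb_version = embed(version.reshape(1, -1), 180, 2)  # (1, 5) -> (1, 5, 2)
emb_missing = embed(missing.reshape(1, -1), 4, 1)    # (1, 5) -> (1, 5, 1)
merged = np.concatenate([num, emb_version, emb_missing], axis=-1)
print(merged.shape)  # (1, 5, 6)

# With two features per timestep the flattened length doubles (10 instead of 5),
# so concatenating with num along the last axis would no longer be possible.
wide = np.ones((1, 5, 2), dtype=int)
print(embed(wide.reshape(1, -1), 180, 2).shape)  # (1, 10, 2)
```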

    Test code:

    if __name__ == '__main__':
        # Fake numeric input of shape 1*5*3; timesteps 0, 2 and 4 are all zeros
        # and should be masked by the Masking layer
        num = np.array([[[0, 0, 0], [1, 2, 3], [0, 0, 0], [1, 2, 3], [0, 0, 0]]])
        # Categorical ids of shape 1*5*1; id 0 is masked by mask_zero=True
        c1 = np.array([[[0], [1], [0], [1], [0]]])
        c2 = np.array([[[0], [1], [0], [1], [0]]])
        y = np.array([[0, 1]])
        x = [num, c1, c2]

        model = rnn_model(x, y)

        # Build a sub-model that exposes the output of one GRU layer
        layer_name = 'gru_1'
        intermediate_model = Model(inputs=model.input,
                                   outputs=model.get_layer(layer_name).output)
        print(intermediate_model.predict(x))
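
    When reading the printed `gru_1` output, conclusion 4 is consistent with the following rule, which this sketch assumes: at a masked timestep the recurrent layer skips its update, carries the hidden state over, and repeats the previous output, starting from an all-zero initial state. The sketch uses a plain tanh RNN cell instead of a GRU, with hypothetical random weights; `masked_rnn` is an illustrative name, not a Keras function.

```python
import numpy as np

def masked_rnn(x, mask, w_x, w_h):
    """Run a simple tanh RNN over x of shape (timesteps, features).

    At a masked timestep (mask[t] is False) the hidden state is carried over
    unchanged, so the previous output is repeated (assumed Keras-like behaviour).
    """
    h = np.zeros(w_h.shape[0])  # initial state is all zeros
    outputs = []
    for t in range(x.shape[0]):
        if mask[t]:
            h = np.tanh(x[t] @ w_x + h @ w_h)  # normal recurrent update
        # else: skip the update and keep h from the previous step
        outputs.append(h.copy())
    return np.stack(outputs)

rng = np.random.default_rng(0)
x = np.array([[0, 0, 0], [1, 2, 3], [0, 0, 0], [1, 2, 3], [0, 0, 0]], dtype=float)
mask = np.any(x != 0, axis=-1)   # [False, True, False, True, False]
w_x = rng.normal(size=(3, 3))    # hypothetical input weights, 3 units
w_h = rng.normal(size=(3, 3))    # hypothetical recurrent weights

out = masked_rnn(x, mask, w_x, w_h)
print(out)
```

    With this mask, row 0 stays at zero and rows 2 and 4 repeat rows 1 and 3, the same pattern described in conclusion 4.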

  • Original post: https://www.cnblogs.com/wh228/p/9863641.html