zoukankan      html  css  js  c++  java
  • keras Model 3 共享的层

    1 入门

    2 多个输入和输出

    3 共享层

    考虑这样的一个问题:我们要判断连个tweet是否来源于同一个人。

    首先我们对两个tweet进行处理,然后将处理的结构拼接在一起,之后跟一个逻辑回归,输出这两条tweet来自同一个人概率。

    因为我们对两条tweet的处理是相同的,所以对第一条tweet的处理的模型,可以被重用来处理第二个tweet。我们考虑用LSTM进行处理。

    假设我们的输入是两条 280*256的向量

    首先定义输入:

    import keras
    from keras.layers import Input, LSTM, Dense
    from keras.models import Model
    
    tweet_a = Input(shape=(280, 256))
    tweet_b = Input(shape=(280, 256))

    然后我们共享LSTM。共享层很简单,只要实例化层一次,然后在你想处理的tensor上调用你想要应用的次数即可(翻译无力,看代码)

    # This layer can take as input a matrix
    # and will return a vector of size 64
    shared_lstm = LSTM(64)
    
    # When we reuse the same layer instance
    # multiple times, the weights of the layer
    # are also being reused
    # (it is effectively *the same* layer)
    encoded_a = shared_lstm(tweet_a)
    encoded_b = shared_lstm(tweet_b)
    
    # We can then concatenate the two vectors:
    merged_vector = keras.layers.concatenate([encoded_a, encoded_b], axis=-1)
    
    # And add a logistic regression on top
    predictions = Dense(1, activation='sigmoid')(merged_vector)
    
    # We define a trainable model linking the
    # tweet inputs to the predictions
    model = Model(inputs=[tweet_a, tweet_b], outputs=predictions)
    
    model.compile(optimizer='rmsprop',
                  loss='binary_crossentropy',
                  metrics=['accuracy'])
    model.fit([data_a, data_b], labels, epochs=10)

    其实,简单点说,对一个层的多次调用,就是在共享这个层。这里有一个层的节点的概念

    当你在一个输入tensor上调用一个层时,就会生成一个输出tensor,就会在这个层上添加一个节点,这个节点连接着这两个tensor(输入tensor和输出tensor)。当你多次调用同一个层的时,

    这个层生成的节点就会按照0 ,1, 2, 。。以此类推编号。

    那么当一个层有多个节点的时候,我们怎么获取它的输出呢?

    如果直接通过output获取会出错:

    a = Input(shape=(280, 256))
    b = Input(shape=(280, 256))
    
    lstm = LSTM(32)
    encoded_a = lstm(a)
    encoded_b = lstm(b)
    
    lstm.output
    >> AttributeError: Layer lstm_1 has multiple inbound nodes,
    hence the notion of "layer output" is ill-defined.
    Use `get_output_at(node_index)` instead.

    这时候应该通过索引进行调用:

    assert lstm.get_output_at(0) == encoded_a
    assert lstm.get_output_at(1) == encoded_b

    对于输入,也是同样的

    a = Input(shape=(32, 32, 3))
    b = Input(shape=(64, 64, 3))
    
    conv = Conv2D(16, (3, 3), padding='same')
    conved_a = conv(a)
    
    # Only one input so far, the following will work:
    assert conv.input_shape == (None, 32, 32, 3)
    
    conved_b = conv(b)
    # now the `.input_shape` property wouldn't work, but this does:
    assert conv.get_input_shape_at(0) == (None, 32, 32, 3)
    assert conv.get_input_shape_at(1) == (None, 64, 64, 3)
  • 相关阅读:
    python实现七段数码管显示
    词频统计实例
    分形几何中科赫雪花的绘制
    脚本实现自动化绘制
    Android查看数据库方法及工具(转)
    2011年度总结
    Bad NPObject as private data 解决方案
    LINQ学习之旅——LINQ TO SQL继承支持(转载)
    我记录开源系统1.6源码解析(一)
    2011年度总结
  • 原文地址:https://www.cnblogs.com/superxuezhazha/p/10973109.html
Copyright © 2011-2022 走看看