  • tf.keras model with multiple inputs fed from tf.data.Dataset

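    import numpy as np
    import tensorflow as tf

    # Two inputs: a sequence of shape (10, 1) and a single scalar feature; batch size is left open
    a = tf.keras.layers.Input(batch_shape=(None, 10, 1))
    b = tf.keras.layers.Input(batch_shape=(None, 1))

    fc1 = tf.keras.layers.Dense(16, 'relu')(a)   # (None, 10, 16)
    fc2 = tf.keras.layers.Dense(16, 'relu')(b)   # (None, 16)

    # Keep only the first time step: (None, 10, 16) -> (None, 16)
    fc1 = tf.keras.layers.Lambda(lambda x: x[:, 0, :])(fc1)
    # fc1 is already (None, 16); this reshape only makes the shape explicit
    reshape = tf.keras.layers.Lambda(lambda x: tf.reshape(x, (-1, 16)))(fc1)
    hidden = tf.keras.layers.concatenate([reshape, fc2], axis=-1)   # (None, 32)
    inputs = [a, b]
    outputs = hidden
    print(hidden.shape)
    model = tf.keras.Model(inputs=inputs, outputs=outputs)

    model.compile(optimizer=tf.keras.optimizers.SGD(),
                  loss=tf.keras.losses.mean_squared_error)

    # Random data: 10 samples; labels have 32 columns to match the model output
    data1 = np.random.rand(10, 10, 1)
    data2 = np.random.rand(10, 1)
    label = np.random.rand(10, 32)

    # Both inputs are sliced together as a tuple, in the same order as `inputs`
    dataset1 = tf.data.Dataset.from_tensor_slices((data1, data2))
    dataset2 = tf.data.Dataset.from_tensor_slices(label)

    # zip yields ((x1, x2), y), the (inputs, targets) structure that fit() expects
    dataset = tf.data.Dataset.zip((dataset1, dataset2)).batch(10).repeat()

    # repeat() makes the dataset infinite, so steps_per_epoch is required
    model.fit(dataset, epochs=5, steps_per_epoch=30)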
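
As a variation on the code above, here is a minimal sketch that names each `Input` and feeds the dataset as a dictionary keyed by those names, so the inputs no longer have to be passed in a fixed tuple order. The names `seq` and `scalar`, the batch size of 5 and the two epochs are illustrative assumptions, not taken from the original post. The sketch also checks that building `((x1, x2), y)` with `Dataset.zip`, as the post does, gives the same element structure as a single nested `from_tensor_slices` call.

```python
import numpy as np
import tensorflow as tf

# Hypothetical input names "seq" and "scalar", chosen for illustration
seq_in = tf.keras.Input(shape=(10, 1), name="seq")
scalar_in = tf.keras.Input(shape=(1,), name="scalar")

h1 = tf.keras.layers.Dense(16, activation="relu")(seq_in)     # (None, 10, 16)
h1 = tf.keras.layers.Lambda(lambda x: x[:, 0, :])(h1)          # (None, 16)
h2 = tf.keras.layers.Dense(16, activation="relu")(scalar_in)  # (None, 16)
out = tf.keras.layers.Concatenate(axis=-1)([h1, h2])           # (None, 32)

# Inputs given as a dict whose keys match the Input names
model = tf.keras.Model(inputs={"seq": seq_in, "scalar": scalar_in}, outputs=out)
model.compile(optimizer="sgd", loss="mse")

x_seq = np.random.rand(10, 10, 1).astype("float32")
x_scalar = np.random.rand(10, 1).astype("float32")
y = np.random.rand(10, 32).astype("float32")

# Each element is ({"seq": ..., "scalar": ...}, label)
dataset = tf.data.Dataset.from_tensor_slices(
    ({"seq": x_seq, "scalar": x_scalar}, y)).batch(5)
history = model.fit(dataset, epochs=2, verbose=0)

# Zipping two datasets (as in the post) vs. slicing one nested tuple
ds_zip = tf.data.Dataset.zip((
    tf.data.Dataset.from_tensor_slices((x_seq, x_scalar)),
    tf.data.Dataset.from_tensor_slices(y),
))
ds_nested = tf.data.Dataset.from_tensor_slices(((x_seq, x_scalar), y))
print(ds_zip.element_spec == ds_nested.element_spec)
```

With the dictionary form, Keras matches the dataset's keys to the model's named inputs; with the list form used in the post, the tuple order must match `inputs=[a, b]` instead. Because this dataset is finite (no `repeat()`), `steps_per_epoch` can be omitted.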
    


  • Original post: https://www.cnblogs.com/lixyuan/p/12818941.html