zoukankan      html  css  js  c++  java
  • 在keras下实现多个模型的融合

    在keras下实现多个模型的融合

    小风风12580 2019-09-30 10:42:00 1105 收藏 7
    展开
    在网上搜过发现关于keras下的模型融合框架其实很简单,奈何网上说了一大堆,这个东西官方文档上就有,自己写了个demo:

    # Function:基于keras框架下实现,多个独立任务分类
    # Writer: PQF
    # Time: 2019/9/29

    import numpy as np
    from keras.layers import Input, Dense
    from keras.models import Model
    import tensorflow as tf

    # 生成训练集
    dataset_size = 128*3
    rdm = np.random.RandomState(1)
    X = rdm.rand(dataset_size,2)
    Y1 = [[int(x1+x2<1)] for (x1,x2) in X]
    Y2 = [[int(x1+x2*x2<0.5)] for (x1,x2) in X]

    X_train = X[:-2]
    Y_train1 = Y1[:-2]
    Y_train2 = Y2[:-2]

    X_test = X[-2:dataset_size]
    Y_test1 = Y1[-2:dataset_size]
    Y_test2 = Y2[-2:dataset_size]

    #网络一
    input = Input(shape=(2,))
    x = Dense(units=16,activation='relu')(input)
    output = Dense(units=1,activation='sigmoid',name='output1')(x)

    #网络二
    input2 = Input(shape=(2,))
    x2 = Dense(units=16,activation='relu')(input2)
    output2 = Dense(units=1,activation='sigmoid',name='output2')(x2)

    #模型合并
    model = Model(inputs=[input,input2],outputs=[output,output2])
    model.summary()

    model.compile(optimizer='rmsprop',loss='binary_crossentropy',loss_weights=[1.0,1.0])
    model.fit([X_train,X_train],[Y_train1,Y_train2],batch_size=48,epochs=200)


    print('x_test is : ')
    print(X_test)
    print('y_test1 is : ')
    print(Y_test1)
    print('y_test2 is : ')
    print(Y_test2)

    predict = model.predict([X_test,X_test])
    print('prediction is : ')
    print(predict[0])
    print(predict[1])

    ————————————————
    版权声明:本文为CSDN博主「小风风12580」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
    原文链接:https://blog.csdn.net/weixin_43392276/java/article/details/101757173

  • 相关阅读:
    Django的forms.ModelForm自定义特殊条件认证。
    对象的属性输出,魔法方法__dict__
    Django从model对象里面提取出字段与属性,并转换成字典。
    刚刚想起猴子布丁,查了点资料,自己实践了下,记录汇总下。
    HTTP通信传输过程详解。
    jsp->jar
    Python overall structer
    SaaS成熟度模型分级:
    FW: linux screen -recorder by ffcast convert
    time-based DB
  • 原文地址:https://www.cnblogs.com/think90/p/12981421.html
Copyright © 2011-2022 走看看