zoukankan      html  css  js  c++  java
  • FM算法的Keras实现

    
    import numpy as np
    import pandas as pd
    import tensorflow as tf
    import keras
    import os
    
    import matplotlib.pyplot as plt
    
    from keras.layers import Layer,Dense,Dropout,Input
    from keras import Model,activations
    from keras.optimizers import Adam
    from keras import backend as K
    from keras.layers import Layer
    from sklearn.datasets import load_breast_cancer
    
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = "0"
    class FM(Layer):
        """Factorization Machine (FM) layer.

        For a rank-2 input ``x`` of shape ``(batch, n_features)`` the layer
        computes, for each of ``output_dim`` output units::

            y = x . W + b + 0.5 * sum_f[ (x . V)_f ** 2 - (x**2 . V**2)_f ]

        The first two terms are a per-unit linear model. The last term is the
        FM pairwise-interaction term in its O(n_features * latent) form. It is
        one scalar per sample and is added identically to every output unit.

        Args:
            output_dim: number of output units (columns of ``W`` / length of ``b``).
            latent: number of latent factors per feature (columns of ``V``).
            activation: activation identifier resolved with
                ``keras.activations.get``. NOTE: it is stored and serialized
                but NOT applied in ``call``, so the layer output is linear.
                It is left unapplied so existing models keep their behaviour.
            **kwargs: forwarded to ``keras.layers.Layer``.
        """

        def __init__(self, output_dim, latent=10, activation='relu', **kwargs):
            # Initialise the base Layer first so attribute tracking is set up
            # before any attributes are assigned.
            super(FM, self).__init__(**kwargs)
            self.latent = latent
            self.output_dim = output_dim
            self.activation = activations.get(activation)

        def build(self, input_shape):
            """Create the bias ``b`` (W0), linear weights ``W`` and factors ``V``."""
            n_features = input_shape[1]
            self.b = self.add_weight(name='W0',
                                     shape=(self.output_dim,),
                                     trainable=True,
                                     initializer='zeros')
            self.w = self.add_weight(name='W',
                                     shape=(n_features, self.output_dim),
                                     trainable=True,
                                     initializer='random_uniform')
            self.v = self.add_weight(name='V',
                                     shape=(n_features, self.latent),
                                     trainable=True,
                                     initializer='random_uniform')
            super(FM, self).build(input_shape)

        def call(self, inputs, **kwargs):
            """Return the linear term plus the pairwise-interaction term.

            Output shape is ``(batch, output_dim)``.
            """
            x = inputs
            linear = K.dot(x, self.w)  # (batch, output_dim)

            # Pairwise interactions: 0.5 * sum_f[(sum_i v_if x_i)^2 - sum_i v_if^2 x_i^2]
            square_of_sum = K.square(K.dot(x, self.v))                 # (batch, latent)
            sum_of_square = K.dot(K.square(x), K.square(self.v))       # (batch, latent)
            pairwise = 0.5 * K.sum(square_of_sum - sum_of_square,
                                   axis=1, keepdims=True)              # (batch, 1)

            # (batch, 1) broadcasts across output_dim, which is equivalent to
            # repeating the pairwise scalar for every output unit.
            return linear + pairwise + self.b

        def compute_output_shape(self, input_shape):
            """Map ``(batch, n_features)`` to ``(batch, output_dim)``.

            Raises:
                ValueError: if the input is not rank 2.
            """
            if not input_shape or len(input_shape) != 2:
                raise ValueError(
                    'FM expects a rank-2 input (batch, n_features), got %s'
                    % (input_shape,))
            return input_shape[0], self.output_dim

        def get_config(self):
            """Return the constructor arguments so the layer can be saved/reloaded."""
            config = super(FM, self).get_config()
            config.update({
                'output_dim': self.output_dim,
                'latent': self.latent,
                'activation': activations.serialize(self.activation),
            })
            return config
    
    
    # Load the dataset once and reuse it for both features and labels.
    dataset = load_breast_cancer()
    data = dataset["data"]
    target = dataset["target"]

    # Keras' `validation_split` holds out the LAST fraction of the samples
    # (selected before shuffling). Standardize the features using statistics
    # from the training portion only, so no information leaks from the
    # validation rows. The FM layer squares its inputs, so unscaled features
    # with large magnitudes would dominate the pairwise-interaction term.
    validation_split = 0.2
    split_at = int(len(data) * (1.0 - validation_split))
    mean = data[:split_at].mean(axis=0)
    std = data[:split_at].std(axis=0)
    std[std == 0] = 1.0  # guard against division by zero for constant columns
    data = (data - mean) / std

    K.clear_session()
    print(target)
    # Derive the input width from the data instead of hard-coding it.
    inputs = Input(shape=(data.shape[1],))
    out = FM(20)(inputs)
    out = Dense(15, activation='sigmoid')(out)
    out = Dense(1, activation='sigmoid')(out)

    model = Model(inputs=inputs, outputs=out)
    # NOTE(review): 'binary_crossentropy' is the conventional loss for a
    # single sigmoid output on 0/1 labels; 'mse' is kept from the original setup.
    model.compile(loss='mse',
                  optimizer='adam',
                  metrics=['acc'])
    model.summary()

    h = model.fit(data, target, batch_size=1, epochs=10,
                  validation_split=validation_split)

    #%%

    plt.plot(h.history['acc'], label='acc')
    plt.plot(h.history['val_acc'], label='val_acc')
    plt.xlabel('epoch')
    plt.ylabel('acc')
    plt.legend()  # the lines are labelled; draw the legend so the labels appear
    plt.show()
    
    #%%
    
  • 相关阅读:
    JavaScript+运算符总结
    【总结】HTML5的sessionStorage和localStorage
    移动H5前端性能优化指南(转自ISUX)
    最新个人H5游戏大作——《择花的少女》
    类似天猫那样的侧边导航栏的快速实现
    jQuery实现banner图片的轮播效果
    实现数字电视机顶盒画面的纯键盘和遥控操作网页
    广播的动态静态注册
    Activity 与 fragment 生命周期
    activitycollector
  • 原文地址:https://www.cnblogs.com/zhouyu0-0/p/12293880.html
Copyright © 2011-2022 走看看