zoukankan      html  css  js  c++  java
  • 将libFM模型变换成tensorflow可serving的形式

    fm_model（即代码中读取的 libfm_model_file）是 libFM 生成的模型文件

    model.ckpt 是导出的、可供 TensorFlow Serving 加载的模型：它由 tf.saved_model.simple_save 生成，实际上是一个 SavedModel 目录，而不是普通的 checkpoint 文件

    亲测输出正确。

    代码:

     1 import tensorflow as tf
     2 
     3 # libFM model
     4 def load_fm_model(file_name):
     5     state = ''
     6     fid = 0
     7     max_fid = 0
     8     w0 = 0.0
     9     wj = {}
    10     v = {}
    11     k = 0
    12     with open(file_name) as f:
    13         for line in f:
    14             line = line.rstrip()
    15             if 'global bias W0' in line:
    16                 state = 'w0'
    17                 fid = 0
    18                 continue
    19             elif 'unary interactions Wj' in line:
    20                 state = 'wj'
    21                 fid = 0
    22                 continue
    23             elif 'pairwise interactions Vj,f' in line:
    24                 state = 'v'
    25                 fid = 0
    26                 continue
    27 
    28             if state == 'w0':
    29                 fv = float(line)
    30                 w0 = fv
    31             elif state == 'wj':
    32                 fv = float(line)
    33                 if fv != 0:
    34                     wj[fid] = fv
    35                 fid += 1
    36                 max_fid = max(max_fid, fid)
    37             elif state == 'v':
    38                 fv = [float(_v) for _v in line.split(' ')]
    39                 k = len(fv)
    40                 if any([_v!=0 for _v in fv]):
    41                     v[fid] = fv
    42                 fid += 1
    43                 max_fid = max(max_fid, fid)
    44     return w0, wj, v, k, max_fid
    45 
    46 _w0, _wj, _v, _k, _max_fid = load_fm_model('libfm_model_file')
    47 
    48 # max feature_id
    49 n = _max_fid
    50 print 'n', n
    51 
    52 # vector dimension
    53 k = _k
    54 print 'k', k
    55 
    56 # write fm algorithm
    57 w0 = tf.constant(_w0)
    58 w1c = tf.constant([_wj.get(fid, 0) for fid in xrange(n)], shape=[n])
    59 w1 = tf.Variable(w1c)
    60 #print 'w1', w1
    61 
    62 vec = []
    63 for fid in xrange(n):
    64     vec.append(_v.get(fid, [0]*k))
    65 w2c = tf.constant(vec, shape=[n,k])
    66 w2 = tf.Variable(w2c)
    67 print 'w2', w2
    68 
    69 # inputs
    70 x = tf.placeholder(tf.string, [None])
    71 batch = tf.shape(x)[0]
    72 x_s = tf.string_split(x)
    73 inds = tf.stack([tf.cast(x_s.indices[:,0], tf.int64), tf.string_to_number(x_s.values, tf.int64)], axis=1)
    74 x_sparse = tf.sparse.SparseTensor(indices=inds, values=tf.ones([tf.shape(inds)[0]]), dense_shape=[batch,n])
    75 x_ = tf.sparse.to_dense(x_sparse)
    76 
    77 w2_rep = tf.reshape(tf.tile(w2, [batch,1]), [-1,n,k])
    78 print 'w2_rep', w2_rep
    79 
    80 x_rep = tf.reshape(tf.tile(tf.reshape(x_, [batch*n, 1]), [1,k]), [-1,n,k])
    81 print 'x_rep', x_rep
    82 x_rep2 = tf.square(x_rep)
    83 
    84 #print tf.multiply(w2_rep,x_rep)
    85 #print tf.reduce_sum(tf.multiply(w2_rep,x_rep), axis=1)
    86 q = tf.square(tf.reduce_sum(tf.multiply(w2_rep, x_rep), axis=1))
    87 h = tf.reduce_sum(tf.multiply(tf.square(w2_rep), x_rep2), axis=1)
    88 
    89 y = w0 + tf.reduce_sum(tf.multiply(x_, w1), axis=1) +
    90     1.0/2 * tf.reduce_sum(q-h, axis=1)
    91 
    92 saver = tf.train.Saver()
    93 with tf.Session() as sess:
    94     sess.run(tf.global_variables_initializer())
    95     #a = sess.run(y, feed_dict={x_:x_train,y_:y_train,batch:70})
    96     #print a
    97     save_path = "./model.ckpt"
    98     tf.saved_model.simple_save(sess, save_path, inputs={"x": x}, outputs={"y": y})

    参考:

    https://blog.csdn.net/u010159842/article/details/78789355（开头借鉴了此文，但该文有不少细节错误）

    https://www.tensorflow.org/guide/saved_model

    http://nowave.it/factorization-machines-with-tensorflow.html

  • 相关阅读:
    106. Construct Binary Tree from Inorder and Postorder Traversal
    105. Construct Binary Tree from Preorder and Inorder Traversal
    449. Serialize and Deserialize BST
    114. Flatten Binary Tree to Linked List
    199. Binary Tree Right Side View
    173. Binary Search Tree Iterator
    98. Validate Binary Search Tree
    965. Univalued Binary Tree
    589. N-ary Tree Preorder Traversal
    eclipse设置总结
  • 原文地址:https://www.cnblogs.com/yaoyaohust/p/10472780.html
Copyright © 2011-2022 走看看