zoukankan      html  css  js  c++  java
  • [阿里DIN] 模型保存,加载和使用

    [阿里DIN] 模型保存,加载和使用

    0x00 摘要

    Deep Interest Network(DIN)是阿里妈妈精准定向检索及基础算法团队在2017年6月提出的。其针对电子商务领域(e-commerce industry)的CTR预估,重点在于充分利用/挖掘用户历史行为数据中的信息。

    本系列文章会解读论文以及源码,顺便梳理一些深度学习相关概念和TensorFlow的实现。

    本文是系列第 12 篇 :介绍DIN模型的保存,加载和使用。

    0x01 TensorFlow模型

    1.1 模型文件

    TensorFlow模型会保存在checkpoint相关文件中。因为TensorFlow会将计算图的结构和图上参数取值分开保存,所以保存后在相关文件夹中会出现3个文件。

    下面就是DIN,DIEN相关生成的文件,可以通过名称来判别。

    checkpoint				
    
    ckpt_noshuffDIN3.data-00000-of-00001
    ckpt_noshuffDIN3.meta
    ckpt_noshuffDIN3.index
    
    ckpt_noshuffDIEN3.data-00000-of-00001	
    ckpt_noshuffDIEN3.index			
    ckpt_noshuffDIEN3.meta
    

    所以我们可以认为和保存的模型直接相关的是以下这四个文件:

    • checkpoint文件保存了一个目录下所有的模型文件列表,这个文件是TensorFlow自动生成且自动维护的。在 checkpoint文件中维护了由一个TensorFlow持久化的所有TensorFlow模型文件的文件名。当某个保存的TensorFlow模型文件被删除时,这个模型所对应的文件名也会从checkpoint文件中删除。checkpoint中内容的格式为CheckpointState Protocol Buffer.
    • .meta文件 保存了TensorFlow计算图的结构,可以理解为神经网络的网络结构。
      TensorFlow通过元图(MetaGraph)来记录计算图中节点的信息以及运行计算图中节点所需要的元数据。TensorFlow中元图是由MetaGraphDef Protocol Buffer定义的。MetaGraphDef 中的内容构成了TensorFlow持久化时的第一个文件。保存MetaGraphDef 信息的文件默认以.meta为后缀名。
    • .index文件保存了当前参数名。
    • model.ckpt文件保存了TensorFlow程序中每一个变量的取值,这个文件是通过SSTable格式存储的,可以大致理解为就是一个(key,value)列表。model.ckpt文件中列表的第一行描述了文件的元信息,比如在这个文件中存储的变量列表。列表剩下的每一行保存了一个变量的片段,变量片段的信息是通过SavedSlice Protocol Buffer定义的。SavedSlice类型中保存了变量的名称、当前片段的信息以及变量取值。TensorFlow提供了tf.train.NewCheckpointReader类来查看model.ckpt文件中保存的变量信息。

    1.2 freeze_graph

    正如前文所述,tensorflow在训练过程中,通常不会将权重数据保存的格式文件里,反而是分开保存在一个叫checkpoint的检查点文件里,当初始化时,再通过模型文件里的变量Op节点来从checkoupoint文件读取数据并初始化变量。这种模型和权重数据分开保存的情况,使得发布产品时不是那么方便,所以便有了freeze_graph.py脚本文件用来将这两文件整合合并成一个文件。

    freeze_graph.py是怎么做的呢?

    • 它先加载模型文件
    • 提供checkpoint文件地址后,它从checkpoint文件读取权重数据初始化到模型里的权重变量;
    • 将权重变量转换成权重常量 (因为常量能随模型一起保存在同一个文件里);
    • 再通过指定的输出节点没用于输出推理的Op节点从图中剥离掉;
    • 使用tf.train.writegraph保存图,这个图会提供给freeze_graph使用;
    • 再使用freeze_graph重新保存到指定的文件里;

    0x02 DIN代码

    因为 DIN 源码中没有实现此部分,所以我们需要自行添加。

    2.1 输出结点

    首先,在model.py中,需要声明输出结点。

    def build_fcn_net(self, inp, use_dice = False):
        .....
        # 此处需要给 y_hat 添加一个name
        self.y_hat = tf.nn.softmax(dnn3, name='final_output') + 0.00000001
    

    2.2 保存函数

    其次,需要添加一个保存函数,调用 freeze_graph 来进行保存。

    需要注意几点:

    • write_graph 的 as_text 参数默认是 True,我们这里设置为 False。有的环境如果设置为 True 会有问题;
    • 因为write_graph 的 as_text 参数做了设置,所以freeze_graph的参数也做相应设置: input_binary=True
    • input_checkpoint 参数需要针对DIN或者DIEN做相应调整;

    具体代码如下:

    def din_freeze_graph(sess):
        # 模型持久化,将变量值固定
        output_graph_def = convert_variables_to_constants(
                sess=sess,
                input_graph_def=sess.graph_def, # 等于:sess.graph_def
                output_node_names=['final_output']) # 如果有多个输出节点,以逗号隔开
        tf.train.write_graph(output_graph_def, 'dnn_best_model', 'model.pb', False)
    
        freeze_graph.freeze_graph(
                input_graph='./dnn_best_model/model.pb',
                input_saver='',
                input_binary=True,
                input_checkpoint='./dnn_best_model/ckpt_noshuffDIN3',
                output_node_names='final_output', # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
                restore_op_name='save/restore_all',
                filename_tensor_name='save/Const:0',
                output_graph='./dnn_best_model/frozen_model.pb',
                clear_devices=False,
                initializer_nodes=''
                )
    
    

    2.2 调用保存

    我们在train函数中,存储模型之后,进行调用。

    def train(...):
                    if (iter % save_iter) == 0:
                        print('save model iter: %d' %(iter))
                        model.save(sess, model_path+"--"+str(iter))
                        freeze_graph(sess) # 此处调用
    

    0x03 验证

    3.1 加载

    加载函数如下:

    def load_graph(fz_gh_fn):
        with tf.gfile.GFile(fz_gh_fn, "rb") as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
    
            with tf.Graph().as_default() as graph:
                tf.import_graph_def(
                    graph_def,
                    input_map=None,
                    return_elements=None,
                    name="prefix"  # 此处可以自己修改
                )
        return graph
    

    调用加载函数如下,我们在加载之后,打印出图中对应节点:

    graph = load_graph('./dnn_best_model/frozen_model.pb')
    for op in graph.get_operations():
        print(op.name, op.values())
    

    从打印结果我们可以看出来,有些op是Inputs相关,final_output节点则是我们之前设定的。

    (u'prefix/Inputs/mid_his_batch_ph', (<tf.Tensor 'prefix/Inputs/mid_his_batch_ph:0' shape=(?, ?) dtype=int32>,))
    (u'prefix/Inputs/cat_his_batch_ph', (<tf.Tensor 'prefix/Inputs/cat_his_batch_ph:0' shape=(?, ?) dtype=int32>,))
    (u'prefix/Inputs/uid_batch_ph', (<tf.Tensor 'prefix/Inputs/uid_batch_ph:0' shape=(?,) dtype=int32>,))
    (u'prefix/Inputs/mid_batch_ph', (<tf.Tensor 'prefix/Inputs/mid_batch_ph:0' shape=(?,) dtype=int32>,))
    (u'prefix/Inputs/cat_batch_ph', (<tf.Tensor 'prefix/Inputs/cat_batch_ph:0' shape=(?,) dtype=int32>,))
    (u'prefix/Inputs/mask', (<tf.Tensor 'prefix/Inputs/mask:0' shape=(?, ?) dtype=float32>,))
    (u'prefix/Inputs/seq_len_ph', (<tf.Tensor 'prefix/Inputs/seq_len_ph:0' shape=(?,) 
                                   
    ......            
                                   
    (u'prefix/final_output', (<tf.Tensor 'prefix/final_output:0' shape=(?, 2) dtype=float32>,))
    

    3.2 验证

    验证数据可以自己炮制,或者就是从测试数据中取出两条即可,我们的验证文件名字为 local_predict_splitByUser

    0	A3BI7R43VUZ1TY	B00JNHU0T2	Literature & Fiction	0989464105B00B01691C14778097321608442845	BooksLiterature & FictionBooksBooks
    
    1	A3BI7R43VUZ1TY	0989464121	Books	0989464105B00B01691C14778097321608442845	BooksLiterature & FictionBooksBooks
    

    验证代码如下,其中feed_dict如何填充,需要根据上节的输出结果来进行相关配置。

    def predict(
            graph,
            predict_file = "local_predict_splitByUser",
            uid_voc = "uid_voc.pkl",
            mid_voc = "mid_voc.pkl",
            cat_voc = "cat_voc.pkl",
            batch_size = 128,
            maxlen = 100):
        gpu_options = tf.GPUOptions(allow_growth=True)
        with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options), graph = graph) as sess:
            predict_data = DataIterator(predict_file, uid_voc, mid_voc, cat_voc, batch_size, maxlen)
            for src, tgt in predict_data:
                uids, mids, cats, mid_his, cat_his, mid_mask, target, sl, noclk_mids, noclk_cats = prepare_data(src, tgt, maxlen, return_neg=True)
                final_output = "prefix/final_output:0"
                feed_dict = {
                    'prefix/Inputs/mid_his_batch_ph:0' : mid_his,
                    'prefix/Inputs/cat_his_batch_ph:0':cat_his,
                    'prefix/Inputs/uid_batch_ph:0':uids,
                    'prefix/Inputs/mid_batch_ph:0':mids,
                    'prefix/Inputs/cat_batch_ph:0':cats,
                    'prefix/Inputs/mask:0':mid_mask,
                    'prefix/Inputs/seq_len_ph:0':sl
                }
                y_hat = sess.run(final_output, feed_dict = feed_dict)
                print(y_hat)
    

    预测结果如下:

    [[0.95820646 0.04179354]
     [0.09431148 0.9056886 ]]
    

    3.3 为什么要在tensor后面加:0

    在上节中,我们可以看到在feed_dict之中,给定的tensor名字后面都带了 :0

    feed_dict = {
        'prefix/Inputs/mid_his_batch_ph:0' : mid_his,
        'prefix/Inputs/cat_his_batch_ph:0':cat_his,
        'prefix/Inputs/uid_batch_ph:0':uids,
        'prefix/Inputs/mid_batch_ph:0':mids,
        'prefix/Inputs/cat_batch_ph:0':cats,
        'prefix/Inputs/mask:0':mid_mask,
        'prefix/Inputs/seq_len_ph:0':sl
    }
    

    这里需要注意,TensorFlow的运算结果不是一个数,而是一个张量结构。张量的命名形式:“node : src_output”,node为节点的名称,src_output 表示当前张量来自来自节点的第几个输出。

    在我们这里,prefix/Inputs/mid_batch_ph 是操作节点,prefix/Inputs/mid_batch_ph:0 才是变量的名字。冒号后面的数字编号表示这个张量是计算节点上的第几个结果

    0xFF 参考

    【TensorFlow】freeze_graph

    [深度学习] TensorFlow中模型的freeze_graph

    TensorFlow模型冷冻以及为什么tensor名字要加:0

    tensorflow实战笔记(19)----使用freeze_graph.py将ckpt转为pb文件

    Tensorflow-GraphDef、MetaGraph、CheckPoint

  • 相关阅读:
    卡特兰数
    hdu 1023 Train Problem II
    hdu 1022 Train Problem
    hdu 1021 Fibonacci Again 找规律
    java大数模板
    gcd
    object dection资源
    Rich feature hierarchies for accurate object detection and semantic segmentation(RCNN)
    softmax sigmoid
    凸优化
  • 原文地址:https://www.cnblogs.com/rossiXYZ/p/14019176.html
Copyright © 2011-2022 走看看