zoukankan      html  css  js  c++  java
  • TensorFlow模型存储为PB形式【BILSTM+CRF 】

    ## TensorFlow模型存储为PB形式【BILSTM+CRF 】
    为什么要采用SavedModel格式呢?其主要优点是SaveModel与语言无关,比如可以使用python语言训练模型,然后在Java中非常方便的加载模型。当然这也不是说checkpoints模型格式做不到,只是在跨语言时比较麻烦。另外如果使用Tensorflow Serving server来部署模型,必须选择SavedModel格式。

    **SavedModel包含啥?**
    一个比较完整的SavedModel模型包含以下内容:

    > assets/
    > assets.extra/
    > variables/
    >         - variables.data-*****-of-*****
    >         - variables.index
    > saved_model.pb

    saved_model.pb是MetaGraphDef,它包含图形结构。
    variables文件夹保存训练所习得的权重。
    assets文件夹可以添加可能需要的外部文件,assets.extra是一个库可以添加其特定assets的地方。


    MetaGraph是一个数据流图,加上其相关的变量、assets和签名。MetaGraphDef是MetaGraph的Protocol Buffer表示。

    assets和assets.extra是可选的,比如本次保存的模型只包含以下的内容:

    > variables/
    >        variables.data-*****-of-*****
    >        variables.index
    > saved_model.pb

    一、 训练结果checkPoint转存为PB
    1.读取checkPoint数据

    graph2 = tf.Graph()
    with graph2.as_default():
    m = BiLSTM_CRF(args)
    saver = tf.train.import_meta_graph("{}model-35640.meta".format('model_path'))
    with tf.Session(graph=graph2) as session:
    saver.restore(session, tf.train.latest_checkpoint('model_path')) #加载ckpt模型
    export_model(session, m)

    2.转存为PB形式
    **参数定义**

    a = session.graph.get_tensor_by_name("a:0")
    b = session.graph.get_tensor_by_name("b:0")
    c = session.graph.get_tensor_by_name("c:0")
    d = session.graph.get_tensor_by_name("d:0")

    x = session.graph.get_tensor_by_name('x:0')
    y = session.graph.get_tensor_by_name('y:0')
    **结构定义**

    **signature对象**,这个对象包含了计算图中输入与输出张量的键值对信息,键即是张量名,值即是protobuff结构的张量。

    prediction_signature = signature_def_utils.build_signature_def(
    inputs={"a": utils.build_tensor_info(a), # 将张量转为protobuff结构的快捷方法,也就是说下面的输入abcd 以及输出 x y都是经过该函数处理之后的结果。
    "b": utils.build_tensor_info(b), # Protobuf是一种平台无关、语言无关、可扩展且轻便高效的序列化数据结构的协议,可以用于网络通信和数据存储。
    "c": utils.build_tensor_info(c),
    "d": utils.build_tensor_info(d)},

    outputs={
    "x": utils.build_tensor_info(x),
    "y": utils.build_tensor_info(y)},

    method_name=signature_constants.PREDICT_METHOD_NAME)

    export_path = 'result_path'
    if os.path.exists(export_path):
    os.system("rm -rf " + export_path)
    print("Export the model to {}".format(export_path))


    **图定义、存储**

    try:
    legacy_init_op = tf.group(
    tf.tables_initializer(), name='legacy_init_op')
    builder = saved_model_builder.SavedModelBuilder(export_path)

    #可以自己定义tag,在签名的定义上更加灵活。
    一个模型可以包含不同的MetaGraphDef,保存图形的CPU版本和GPU版本,或者你想区分训练和发布版本。这个时候tag就可以用来区分不同的MetaGraphDef,加载的时候能够根据tag来加载模型的不同计算图。

    builder.add_meta_graph_and_variables(
    session, [tag_constants.SERVING], # 系统会给一个默认的tag: “serve”,也可以用tag_constants.SERVING这个常量。
    clear_devices=True,
    signature_def_map={
    'predict_images':
    prediction_signature,
    },
    # legacy_init_op=legacy_init_op,
    main_op=tf.tables_initializer(),
    strip_default_attrs=True
    )


    builder.save()
    print('Done exporting!')

    except Exception as e:
    print("Fail to export saved model, exception: {}".format(e))


    2.加载PB Model

    session = tf.Session(graph=tf.Graph())
    model_file_path = 'data_path_save'
    meta_graph = tf.saved_model.loader.load(session, [tf.saved_model.tag_constants.SERVING], model_file_path)
    model_graph_signature = list(meta_graph.signature_def.items())[0][1] #存在多个图结构时
    output_tensor_names = []
    output_op_names = []
    for output_item in model_graph_signature.outputs.items():
    output_op_name = output_item[0]
    output_op_names.append(output_op_name)
    output_tensor_name = output_item[1].name
    output_tensor_names.append(output_tensor_name)
    print("load model finish!")


    3.进行预测


    sentences = {}
    # 测试pb模型
    for test_x in [['周杰伦爱喝奶茶'],['习主席在陕西'],['刘德华的老婆是不是朱丽倩'],['张学友的老婆是谁咱不知道'],['解放牌卡车是不是在北京生产的']]:
    sentences, seq_len_list = _preprocess(test_x)

    feed_dict_map = {}
    for input_item in model_graph_signature.inputs.items():
    input_op_name = input_item[0]
    input_tensor_name = input_item[1].name
    feed_dict_map[input_tensor_name] = sentences[input_op_name]

    logits, transition_params = session.run(output_tensor_names, feed_dict=feed_dict_map)

    tag = predict(logits,transition_params, seq_len_list)

    label2tag = {}
    for tag, label in tag2label.items():
    label2tag[label] = tag if label != 0 else label

    tag = [label2tag[label] for label in label_list[0]]

    PER, LOC, ORG = get_entity(tag, list(''.join(test_x).strip()))
    print('PER: {} LOC: {} ORG: {}'.format(PER, LOC, ORG))

  • 相关阅读:
    dd——留言板再加验证码功能
    怎样去除织梦版权信息中的Power by DedeCms
    数据结构和算法的选择
    数据结构和算法9——哈希表
    数据结构与算法8——二叉树
    数据结构与算法7——高级排序
    数据结构与算法6——递归
    数据结构和算法5——链表
    数据结构与算法4——栈和队列
    数据结构与算法3——简单排序(冒泡、选择、插入排序)
  • 原文地址:https://www.cnblogs.com/hanhaotian/p/12875695.html
Copyright © 2011-2022 走看看