zoukankan      html  css  js  c++  java
  • bert文本分类模型保存为savedmodel方式

    默认bert是ckpt,在进行后期优化和部署时,savedmodel方式更加友好写。

    train完成后,调用如下函数:

    def save_savedmodel(estimator, serving_dir, seq_length, is_tpu_estimator):
        feature_map = {
            "input_ids": tf.placeholder(tf.int32, shape=[None, seq_length], name='input_ids'),
            "input_mask": tf.placeholder(tf.int32, shape=[None, seq_length], name='input_mask'),
            "segment_ids": tf.placeholder(tf.int32, shape=[None, seq_length], name='segment_ids'),
            "label_ids": tf.placeholder(tf.int32, shape=[None], name='label_ids'),
        }
        serving_input_receiver_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(feature_map)
        estimator.export_savedmodel(serving_dir,
                                    serving_input_receiver_fn,
                                    strip_default_attrs=True)
        print("保存savedmodel")

    estimator:estimator = Estimator(model_fn=model_fn,params={},config=run_config)

    serving_dir:存储目录

    seq_length:样本长度

    is_tpu_estimator: tpu标志位

     
     
  • 相关阅读:
    矩阵快速幂
    快速排序
    用闭包来实现命令模式
    闭包和面向对象设计
    闭包及其作用
    阿里笔试
    如何在节点上添加样式
    getComputedStyle与currentStyle获取样式(style/class)
    今日头条笔试
    牛客网JavaScript编程规范
  • 原文地址:https://www.cnblogs.com/demo-deng/p/13892636.html
Copyright © 2011-2022 走看看