zoukankan      html  css  js  c++  java
  • bert-as-service输出分类结果

    bert-as-service: Mapping a variable-length sentence to a fixed-length vector using BERT model

    默认情况下bert-as-service只提供固定长度的特征向量,如果想要直接获取分类预测结果呢?

    bert提供了的run_classifier.py 以训练分类模型,同时bert提供了离线评估的方法。

    一些可能的部署思路

    • bert基于tensorflow实现,可以参考tensorflow-serving对外提供部署服务
    • 参考bert代码修改离线接口为在线推断,基于flask/django提供部署服务
    • 修改bert-as-service提供高效在线预测服务

    bert-as-service的强大可以参考:Serving Google BERT in Production using Tensorflow and ZeroMQ

    修改bert-as-service提供分类预测

    思路:https://github.com/hanxiao/bert-as-service/issues/213

    bert-as-service 默认情况下,不会加载分类层

    1. 加载模型的同时加载分类层的权重和bias
    2. 添加分类层

    graph.py#L79中添加

                if args.pooling_strategy == PoolingStrategy.CLASSIFICATION:
                     hidden_size = 768
                     output_weights = tf.get_variable(
                         "output_weights", [args.num_labels, hidden_size],
                         )
    
                      output_bias = tf.get_variable(
                         "output_bias", [args.num_labels])
    
                  tvars = tf.trainable_variables()		            
    

    注意:在加载权重和bias的时候不要定义初始化方法,否则会从初始化方法进行加载,而不是微调模型。

    graph.py#L127添加

                    elif args.pooling_strategy == PoolingStrategy.CLASSIFICATION:
                         # pooled = tf.squeeze(encoder_layer[:, 0:1, :], axis=1)
                         logits = tf.matmul(pooled, output_weights, transpose_b=True)
                         logits = tf.nn.bias_add(logits, output_bias)
                         pooled = tf.nn.softmax(logits, axis=-1)
    

    具体代码github

  • 相关阅读:
    Git 总结
    .net报错大全
    对于堆和栈的理解
    html 局部打印
    c#面试问题总结
    算法题总结
    h5-plus.webview
    堆和栈,引用类型,值类型,指令,指针
    .NET framework具体解释
    前端之间的url 传值
  • 原文地址:https://www.cnblogs.com/zyl007/p/12995744.html
Copyright © 2011-2022 走看看