zoukankan      html  css  js  c++  java
  • bert,albert的快速训练和预测

      随着预训练模型越来越成熟,预训练模型也会更多的在业务中使用,本文提供了bert和albert的快速训练和部署,实际上目前的预训练模型在用起来时都大致相同。

      基于不久前发布的中文数据集chineseGLUE,将所有任务分成四大类:文本分类,句子对判断,实体识别,阅读理解。同类可以共享代码,除上面四个任务之外,还加了一个learning to rank ,基于pair wise的方式的任务,代码见:https://github.com/jiangxinyang227/bert-for-task

      具体使用见readme

      模型定义在每个项目下的model.py文件中,直接调用bert和albert的源码modeling.py将预训练模型引入,将预训练模型作为encoder部分,也可以只作为embedding层,再自己定义encoder部分,总之可以非常方便的接入下游任务网络层,尤其是当你只想使用预训练模型作为embedding层时,我们需要自己些encoder部分。

         bert_config = modeling.BertConfig.from_json_file(self.__bert_config_path)
    
            model = modeling.BertModel(config=bert_config,
                                       is_training=self.__is_training,
                                       input_ids=self.input_ids,
                                       input_mask=self.input_masks,
                                       token_type_ids=self.segment_ids,
                                       use_one_hot_embeddings=False)
            output_layer = model.get_pooled_output()
    
            hidden_size = output_layer.shape[-1].value
            if self.__is_training:
                # I.e., 0.1 dropout
                output_layer = tf.nn.dropout(output_layer, keep_prob=0.9)
    
            with tf.name_scope("output"):
                output_weights = tf.get_variable(
                    "output_weights", [self.__num_classes, hidden_size],
                    initializer=tf.truncated_normal_initializer(stddev=0.02))
    
                output_bias = tf.get_variable(
                    "output_bias", [self.__num_classes], initializer=tf.zeros_initializer())
    
                logits = tf.matmul(output_layer, output_weights, transpose_b=True)
                logits = tf.nn.bias_add(logits, output_bias)
                self.predictions = tf.argmax(logits, axis=-1, name="predictions")

      在训练时加载预训练的参数值来初始化预训练模型的变量,具体在trainer.py文件中

    tvars = tf.trainable_variables()
                (assignment_map, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(
                    tvars, self.__bert_checkpoint_path)
    print("init bert model params")
    tf.train.init_from_checkpoint(self.
    __bert_checkpoint_path, assignment_map) print("init bert model params done") sess.run(tf.variables_initializer(tf.global_variables()))

      在预测时可以直接实例化predict.py文件中的Predictor类就会加载checkpoint模型文件,调用类中的predict方法就可以进行预测,在不需要考虑模型代码加密,模型优化等情况下,可以直接线上部署。

    import json
    
    from predict import Predictor
    
    
    with open("config/tnews_config.json", "r") as fr:
        config = json.load(fr)
    
    
    predictor = Predictor(config)
    text = "歼20座舱盖上的两条“花纹”是什么?"
    res = predictor.predict(text)
    print(res)
  • 相关阅读:
    在IDEA上本地更新同步Git中的更改
    protobuf的序列化和反序列化
    关于Pytorch报警告:Warning: indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead
    990. 等式方程的可满足性
    死锁
    事务隔离
    Lab-1
    软件测试homework3
    TCP/UDP网络连接的固定写法
    软件测试Homework 2
  • 原文地址:https://www.cnblogs.com/jiangxinyang/p/11882270.html
Copyright © 2011-2022 走看看