Named Entity Recognition: Dynamically Fusing Features from Different BERT Layers (TensorFlow)

        # Assumes the usual imports: numpy as np, tensorflow as tf, and
        # BertModel/BertConfig from the google-research BERT modeling.py.
        num_labels = self.config.relation_num
        bert_config_file = self.config.bert_config_file
        bert_config = BertConfig.from_json_file(bert_config_file)

        model = BertModel(
            config=bert_config,
            is_training=self.is_training,  # enables dropout during fine-tuning
            input_ids=self.input_x_word,
            input_mask=self.input_mask,
            token_type_ids=None,
            use_one_hot_embeddings=False)

        # For token-level output use model.get_sequence_output();
        # model.get_pooled_output() would give the sentence-level [?, 768] vector instead.
        output_layer = model.get_sequence_output()
        print("output_layer.shape:", output_layer.shape)
        hidden_size = output_layer.shape[-1].value  # 768

        print("=============================")
        print("Shapes of the tensors involved in layer fusion")
        # Project each encoder layer's 768-dim token features down to one logit per token.
        layer_logits = []
        for i, layer in enumerate(model.all_encoder_layers):
            print("layer:", layer)
            layer_logits.append(
                tf.layers.dense(
                    layer, 1,
                    kernel_initializer=tf.truncated_normal_initializer(stddev=0.02),
                    name="layer_logit%d" % i
                )
            )
        print("np.array(layer_logits).shape:", np.array(layer_logits).shape)
        layer_logits = tf.concat(layer_logits, axis=2)  # concat along the last axis -> [batch_size, max_len, 12]
        print("layer_logits.shape:", layer_logits.shape)
        layer_dist = tf.nn.softmax(layer_logits)  # per-token weights over the 12 layers: [batch_size, max_len, 12]
        print("layer_dist.shape:", layer_dist.shape)
        # Stack the 12 encoder layers: [batch_size, max_len, 12, 768]
        seq_out = tf.concat([tf.expand_dims(x, axis=2) for x in model.all_encoder_layers], axis=2)
        print("seq_out.shape:", seq_out.shape)
        # [batch_size, max_len, 1, 12] x [batch_size, max_len, 12, 768] -> [batch_size, max_len, 1, 768]
        pooled_output = tf.matmul(tf.expand_dims(layer_dist, axis=2), seq_out)
        pooled_output = tf.squeeze(pooled_output, axis=2)  # [batch_size, max_len, 768]
        print("pooled_output.shape:", pooled_output.shape)
        pooled_layer = pooled_output
        print("=============================")

    Output:

    =============================
    Shapes of the tensors involved in layer fusion
    layer: Tensor("bert/encoder/Reshape_2:0", shape=(?, ?, 768), dtype=float32)
    layer: Tensor("bert/encoder/Reshape_3:0", shape=(?, ?, 768), dtype=float32)
    layer: Tensor("bert/encoder/Reshape_4:0", shape=(?, ?, 768), dtype=float32)
    layer: Tensor("bert/encoder/Reshape_5:0", shape=(?, ?, 768), dtype=float32)
    layer: Tensor("bert/encoder/Reshape_6:0", shape=(?, ?, 768), dtype=float32)
    layer: Tensor("bert/encoder/Reshape_7:0", shape=(?, ?, 768), dtype=float32)
    layer: Tensor("bert/encoder/Reshape_8:0", shape=(?, ?, 768), dtype=float32)
    layer: Tensor("bert/encoder/Reshape_9:0", shape=(?, ?, 768), dtype=float32)
    layer: Tensor("bert/encoder/Reshape_10:0", shape=(?, ?, 768), dtype=float32)
    layer: Tensor("bert/encoder/Reshape_11:0", shape=(?, ?, 768), dtype=float32)
    layer: Tensor("bert/encoder/Reshape_12:0", shape=(?, ?, 768), dtype=float32)
    layer: Tensor("bert/encoder/Reshape_13:0", shape=(?, ?, 768), dtype=float32)
    np.array(layer_logits).shape: (12,)
    layer_logits.shape: (?, ?, 12)
    layer_dist.shape: (?, ?, 12)
    seq_out.shape: (?, ?, 12, 768)
    pooled_output.shape: (?, ?, 768)
    =============================

    Explanation:

    The Chinese BERT base model has 12 encoder layers, each of which produces its own token-level features; all of them are available through model.all_encoder_layers. We project each layer's 768-dimensional features down to a single logit, concatenate the 12 logits along the last dimension, and apply a softmax so that every token gets a weight for each layer. The batched matmul [batch_size, max_len, 1, 12] × [batch_size, max_len, 12, 768] then produces [batch_size, max_len, 1, 768], and squeezing out the extra dimension leaves [batch_size, max_len, 768]. The result is a dynamically weighted fusion of all layers, which can then be fed into the downstream fine-tuning task.
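
    To make the weighted-sum interpretation concrete, here is a minimal NumPy sanity check (batch and sequence sizes are made up for illustration) showing that the matmul above is exactly a per-token weighted average of the 12 layer vectors:

    import numpy as np

    batch, max_len, num_layers, hidden = 2, 5, 12, 768
    layers = np.random.randn(batch, max_len, num_layers, hidden)  # stands in for seq_out
    dist = np.random.rand(batch, max_len, num_layers)
    dist /= dist.sum(axis=-1, keepdims=True)  # stands in for layer_dist (rows sum to 1)

    # [batch, max_len, 1, 12] x [batch, max_len, 12, 768] -> squeeze -> [batch, max_len, 768]
    fused = np.matmul(dist[:, :, None, :], layers).squeeze(axis=2)
    manual = (dist[..., None] * layers).sum(axis=2)  # the explicit weighted sum over layers
    print(fused.shape)                 # (2, 5, 768)
    print(np.allclose(fused, manual))  # True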

    For example:

    output_weights = tf.get_variable(
            "output_weights", [num_labels, hidden_size],
            initializer=tf.truncated_normal_initializer(stddev=0.02))  # [11, 768]
    print("output_weights.shape:", output_weights.shape)

    output_bias = tf.get_variable(
            "output_bias", [num_labels], initializer=tf.zeros_initializer())  # [11]

    with tf.variable_scope("loss"):

          # Token-level classification: project each fused 768-dim token vector onto the label space.
          logits = tf.matmul(pooled_layer, output_weights, transpose_b=True)  # [?,?,768] x [768,11] = [?,?,11]
          self.logits = tf.nn.bias_add(logits, output_bias)
          self.probabilities = tf.nn.softmax(self.logits, axis=-1)
          log_probs = tf.nn.log_softmax(self.logits, axis=-1)  # [?,?,11]
          print("log_probs.shape:", log_probs.shape)

          self.predictions = tf.argmax(self.logits, axis=-1, name="predictions")

          # Per-token cross-entropy against one-hot labels.
          one_hot_labels = tf.one_hot(self.input_relation, depth=num_labels, dtype=tf.float32)  # [?,512,11]
          self.per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
          self.loss = tf.reduce_mean(self.per_example_loss)
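
    One caveat: tf.reduce_mean above averages the per-token loss over every position, including [PAD] tokens. A common refinement (not part of the original post) is to mask out padded positions with input_mask before averaging; a minimal sketch, assuming self.input_mask holds 1 for real tokens and 0 for padding:

    mask = tf.cast(self.input_mask, tf.float32)  # [?, max_len], 1 = real token, 0 = padding
    masked_loss = self.per_example_loss * mask   # zero the loss on padded positions
    # Average over real tokens only (epsilon guards against an all-padding batch).
    self.loss = tf.reduce_sum(masked_loss) / (tf.reduce_sum(mask) + 1e-9)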
Original article: https://www.cnblogs.com/xiximayou/p/14128613.html