  • [Relation Extraction: mre-in-one-pass] Loading the Data (Part 2)

    Continuing from the previous post, Loading the Data (Part 1).

    Last time we left off at:

    convert_single_example(ex_index, example, label_list, max_seq_length,
                               tokenizer)
    

    Inside this function, the following is called:

    loc, mas, e1_mas, e2_mas = prepare_extra_data(mapping_a, example.locations, FLAGS.max_distance)
    

    And prepare_extra_data in turn calls two functions:

    convert_entity_row(mapping, e, max_distance)
    find_lo_hi(mapping, lo)
    

    Let's work through prepare_extra_data step by step:

    • Four arrays are defined at the start:
      res = np.zeros([FLAGS.max_seq_length, FLAGS.max_seq_length], dtype=np.int8)
      mas = np.zeros([FLAGS.max_seq_length, FLAGS.max_seq_length], dtype=np.int8)
      
      e1_mas = np.zeros([FLAGS.max_num_relations, FLAGS.max_seq_length], dtype=np.int8)
      e2_mas = np.zeros([FLAGS.max_num_relations, FLAGS.max_seq_length], dtype=np.int8)
    

    First, an overview of what these are:
    (1) res: the relative-position matrix, of shape [128, 128], where 128 is the maximum sentence length. It records the relative position between each entity and every other token.
    (2) mas: the entity mask matrix, also of shape [128, 128]: positions covered by an entity are 1, everything else is 0.
    (3) e1_mas: the mask matrix for entity 1 of each relation, of shape [12, 128], where 12 is the configured maximum number of relations per example.
    (4) e2_mas: the mask matrix for entity 2 of each relation, also of shape [12, 128].
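    The four arrays above can be sketched standalone (assuming max_seq_length=128 and max_num_relations=12, the values used throughout this post):

    ```python
    import numpy as np

    MAX_SEQ_LENGTH = 128     # maximum tokenized sentence length
    MAX_NUM_RELATIONS = 12   # maximum relations per example

    # relative-position matrix and entity mask matrix, both [seq, seq]
    res = np.zeros([MAX_SEQ_LENGTH, MAX_SEQ_LENGTH], dtype=np.int8)
    mas = np.zeros([MAX_SEQ_LENGTH, MAX_SEQ_LENGTH], dtype=np.int8)
    # per-relation masks for entity 1 and entity 2, both [relations, seq]
    e1_mas = np.zeros([MAX_NUM_RELATIONS, MAX_SEQ_LENGTH], dtype=np.int8)
    e2_mas = np.zeros([MAX_NUM_RELATIONS, MAX_SEQ_LENGTH], dtype=np.int8)

    print(res.shape, e1_mas.shape)  # (128, 128) (12, 128)
    ```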

    • Collect the set of entities across all relations:
    entities = set()
    for loc in locs:
        entities.add(loc[0])
        entities.add(loc[1])
    
    • Next comes the key part:
      for e in entities:
        (lo, hi) = e
        relative_position, _ = convert_entity_row(mapping, e, max_distance)
        sub_lo1, sub_hi1 = find_lo_hi(mapping, lo)
        sub_lo2, sub_hi2 = find_lo_hi(mapping, hi)
        if sub_lo1 == 0 and sub_hi1 == 0:
          continue
        if sub_lo2 == 0 and sub_hi2 == 0:
          continue
        # col
        res[:, sub_lo1:sub_hi2+1] = np.expand_dims(relative_position, -1)
        mas[1:, sub_lo1:sub_hi2+1] = 1
    

    Let's first look at the sample data:

    example.text_a = a large database . Traditional information retrieval techniques use a histogram of keywords as the document representation but oral communication may offer additional indices such as the time and is shown on a large database of TV shows . Emotions and other indices
    tokens_a = ['a', 'large', 'database', '.', 'traditional', 'information', 'retrieval', 'techniques', 'use', 'a', 'his', '##to', '##gram', 'of', 'key', '##words', 'as', 'the', 'document', 'representation', 'but', 'oral', 'communication', 'may', 'offer', 'additional', 'indices', 'such', 'as', 'the', 'time', 'and', 'is', 'shown', 'on', 'a', 'large', 'database', 'of', 'tv', 'shows', '.', 'emotions', 'and', 'other', 'indices']
    mapping_a = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 11, 11, 12, 13, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44]
    example.locations = [((13, 13), (6, 8)), ((19, 20), (24, 24)), ((37, 38), (35, 35))]
    entities = {(6, 8), (13, 13), (35, 35), (37, 38), (24, 24), (19, 20)}
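    The entities set is just both spans of every relation in example.locations collected into one set; a minimal standalone sketch using the data above:

    ```python
    # each element of locations is a pair of entity spans: ((lo1, hi1), (lo2, hi2))
    locations = [((13, 13), (6, 8)), ((19, 20), (24, 24)), ((37, 38), (35, 35))]

    entities = set()
    for e1, e2 in locations:
        entities.add(e1)
        entities.add(e2)

    print(sorted(entities))
    # [(6, 8), (13, 13), (19, 20), (24, 24), (35, 35), (37, 38)]
    ```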
    

    For each entity span, relative_position, _ = convert_entity_row(mapping, e, max_distance) is called. The function looks like this:

    def convert_entity_row(mapping, loc, max_distance):
      """
      convert an entity span(lo,hi) to a relative distance vector of shape [max_seq_length]
      """
      lo, hi = loc
      res = [max_distance] * FLAGS.max_seq_length
      mas = [0] * FLAGS.max_seq_length
      for i in range(FLAGS.max_seq_length):
        if i < len(mapping):
          val = mapping[i]
          if val < lo - max_distance:
            res[i] = max_distance
          elif val < lo:
            res[i] = lo - val
          elif val <= hi:
            res[i] = 0
            mas[i] = 1
          elif val <= hi + max_distance:
            res[i] = val - hi + max_distance
          else:
            res[i] = 2 * max_distance
        else:
          res[i] = 2 * max_distance
      return res, mas
    

    For the entity (6, 8), its output is:

    lo = 6
    hi = 8
    res = [4, 4, 4, 3, 2, 1, 0, 0, 0, 5, 6, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8]
    mas = [0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
    relative_position = [4, 4, 4, 3, 2, 1, 0, 0, 0, 5, 6, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8]
    

    Here the maximum distance is set to 4. In res, positions inside the entity get a relative position of 0. For tokens to the left of the entity, if their distance to the entity's left boundary is less than the maximum distance, the value is that distance; tokens further away all get the maximum distance. The right side works the same way, except the values start from the maximum distance and are capped at twice the maximum. Note that because of WordPiece splitting, all subwords of a single word share the same position, e.g. the 7, 7, 7 above. If this is hard to follow, the output above makes it concrete.
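    To make the distance scheme easier to trace, here is a self-contained copy of convert_entity_row run on a toy input (the sequence length of 12 and max_distance of 2 are illustrative values, not this post's configuration):

    ```python
    MAX_SEQ_LENGTH = 12  # shortened for the demo

    def convert_entity_row(mapping, loc, max_distance):
        """Convert an entity span (lo, hi) to a relative-distance vector."""
        lo, hi = loc
        res = [max_distance] * MAX_SEQ_LENGTH
        mas = [0] * MAX_SEQ_LENGTH
        for i in range(MAX_SEQ_LENGTH):
            if i < len(mapping):
                val = mapping[i]
                if val < lo - max_distance:
                    res[i] = max_distance             # too far on the left
                elif val < lo:
                    res[i] = lo - val                 # within range, left side
                elif val <= hi:
                    res[i] = 0                        # inside the entity
                    mas[i] = 1
                elif val <= hi + max_distance:
                    res[i] = val - hi + max_distance  # within range, right side
                else:
                    res[i] = 2 * max_distance         # too far on the right
            else:
                res[i] = 2 * max_distance             # padding positions
        return res, mas

    # identity mapping (no WordPiece splits); entity spans word indices 4..5
    res, mas = convert_entity_row(list(range(10)), (4, 5), max_distance=2)
    print(res)  # [2, 2, 2, 1, 0, 0, 3, 4, 4, 4, 4, 4]
    print(mas)  # [0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0]
    ```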
    As for:

    def find_lo_hi(mapping, value):
      """
      find the boundary of a value in a list
      will return (0,0) if no such value in the list
      """
      try:
        lo = mapping.index(value)
        hi = min(len(mapping) - 1 - mapping[::-1].index(value), FLAGS.max_seq_length)
        return (lo, hi)
      except:
        return (0,0)
    

    Because of the WordPiece split, a word-level index may map to several subword positions after tokenization, so for hi we have to search from the end (the reversed-list index) to find the last occurrence.
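    A standalone sketch: in the mapping above, word index 11 ('histogram') covers three WordPiece positions, so the first and last occurrences delimit the subword span. The 128 default here stands in for FLAGS.max_seq_length:

    ```python
    def find_lo_hi(mapping, value, max_seq_length=128):
        """Find the first/last subword position of a word index; (0, 0) if absent."""
        try:
            lo = mapping.index(value)  # first occurrence
            # last occurrence, found by indexing into the reversed list
            hi = min(len(mapping) - 1 - mapping[::-1].index(value), max_seq_length)
            return (lo, hi)
        except ValueError:
            return (0, 0)

    mapping = [0, 1, 2, 2, 2, 3]   # word 2 was split into three WordPieces
    print(find_lo_hi(mapping, 2))  # (2, 4)
    print(find_lo_hi(mapping, 9))  # (0, 0) -- not in the mapping
    ```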

    • Next, the position information is laid out in matrix form, via the two loops below:
      for e in entities:
        (lo, hi) = e
        relative_position, _ = convert_entity_row(mapping, e, max_distance)
        sub_lo1, sub_hi1 = find_lo_hi(mapping, lo)
        sub_lo2, sub_hi2 = find_lo_hi(mapping, hi)
        if sub_lo1 == 0 and sub_hi1 == 0:
          continue
        if sub_lo2 == 0 and sub_hi2 == 0:
          continue
        # col
        res[:, sub_lo1:sub_hi2+1] = np.expand_dims(relative_position, -1)
        mas[1:, sub_lo1:sub_hi2+1] = 1
    
      for e in entities:
        (lo, hi) = e
        relative_position, _ = convert_entity_row(mapping, e, max_distance)
        sub_lo1, sub_hi1 = find_lo_hi(mapping, lo)
        sub_lo2, sub_hi2 = find_lo_hi(mapping, hi)
        if sub_lo1 == 0 and sub_hi1 == 0:
          continue
        if sub_lo2 == 0 and sub_hi2 == 0:
          continue
        # row
        res[sub_lo1:sub_hi2+1, :] = relative_position
        mas[sub_lo1:sub_hi2+1, 1:] = 1
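    The two passes can be seen in miniature on a toy 6×6 matrix, assuming a single entity covering subword positions 2–3 whose relative-position vector has already been computed:

    ```python
    import numpy as np

    rel = np.array([2, 1, 0, 0, 3, 4], dtype=np.int8)  # distances to an entity at positions 2..3
    res = np.zeros([6, 6], dtype=np.int8)
    mas = np.zeros([6, 6], dtype=np.int8)
    sub_lo, sub_hi = 2, 3

    # col pass: broadcast the vector down the entity's columns
    res[:, sub_lo:sub_hi + 1] = np.expand_dims(rel, -1)
    mas[1:, sub_lo:sub_hi + 1] = 1
    # row pass: write the vector across the entity's rows (overwrites the column values there)
    res[sub_lo:sub_hi + 1, :] = rel
    mas[sub_lo:sub_hi + 1, 1:] = 1

    print(res)
    # [[0 0 2 2 0 0]
    #  [0 0 1 1 0 0]
    #  [2 1 0 0 3 4]
    #  [2 1 0 0 3 4]
    #  [0 0 3 3 0 0]
    #  [0 0 4 4 0 0]]
    ```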
    

    The result looks like this:

    [[0 0 0 0 0 0 4 4 4 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
     [0 0 0 0 0 0 4 4 4 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
     [0 0 0 0 0 0 4 4 4 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
     [0 0 0 0 0 0 3 3 3 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
     [0 0 0 0 0 0 2 2 2 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
     [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
     [4 4 4 3 2 1 0 0 0 5 6 7 7 7 8 ... 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
     [4 4 4 3 2 1 0 0 0 5 6 7 7 7 8 ... 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
     [4 4 4 3 2 1 0 0 0 5 6 7 7 7 8 ... 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
     [0 0 0 0 0 0 5 5 5 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
     [0 0 0 0 0 0 6 6 6 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
     [0 0 0 0 0 0 7 7 7 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
     [0 0 0 0 0 0 7 7 7 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
     [0 0 0 0 0 0 7 7 7 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
     [0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
     ...
     [0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
     [0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
     [0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
     [0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
     [0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
     [0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
     [0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
     [0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
     [0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
     [0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
     [0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
     [0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
     [0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
     [0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
     [0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]
    

    Finally come the entity mask matrices:

      for idx, (e1,e2) in enumerate(locs):
        # e1
        (lo, hi) = e1
        _, mask = convert_entity_row(mapping, e1, max_distance)
        e1_mas[idx] = mask
        # e2
        (lo, hi) = e2
        _, mask = convert_entity_row(mapping, e2, max_distance)
        e2_mas[idx] = mask
    

    The result:

    [[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
     [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
     [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
     [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
     [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
     [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
     [0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 ... 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
     [0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 ... 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
     [0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 ... 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
     [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
     [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
     [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
     [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
     [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
     [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
     ...
     [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
     [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
     [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
     [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
     [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
     [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
     [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
     [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
     [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
     [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
     [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
     [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
     [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
     [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
     [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]
    
    • Back in the convert_single_example function:
      label_id = [label_map[label] for label in example.labels]
      label_id = label_id + [0] * (FLAGS.max_num_relations - len(label_id))
      cls_mask = [1] * example.num_relations + [0] * (FLAGS.max_num_relations - example.num_relations)
    

    The maximum number of relations here is set to 12. First, the output:

    labels: 5 5 2 0 0 0 0 0 0 0 0 0
    cls_mask:[1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
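    The padding logic above, sketched standalone for this example (3 relations with label ids 5, 5, 2, padded out to max_num_relations = 12):

    ```python
    MAX_NUM_RELATIONS = 12

    label_id = [5, 5, 2]          # one label id per relation in the sentence
    num_relations = len(label_id)

    # pad the labels with 0 and mark the real relations in the mask
    label_id = label_id + [0] * (MAX_NUM_RELATIONS - num_relations)
    cls_mask = [1] * num_relations + [0] * (MAX_NUM_RELATIONS - num_relations)

    print(label_id)  # [5, 5, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0]
    print(cls_mask)  # [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
    ```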
    

    In other words, a single sentence can contain entities participating in multiple relations.
    Finally, all of this information is wrapped into an InputFeatures object and returned.

    • Back to the file_based_convert_examples_to_features function:
    def file_based_convert_examples_to_features(
        examples, label_list, max_seq_length, tokenizer, output_file):
      """Convert a set of `InputExample`s to a TFRecord file."""
    
      writer = tf.python_io.TFRecordWriter(output_file)
    
      for (ex_index, example) in enumerate(examples):
        if ex_index % 10000 == 0:
          tf.logging.info("Writing example %d of %d" % (ex_index, len(examples)))
    
        feature = convert_single_example(ex_index, example, label_list,
                                         max_seq_length, tokenizer)
    
        def create_int_feature(values):
          f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
          return f
    
        features = collections.OrderedDict()
        features["input_ids"] = create_int_feature(feature.input_ids)
        features["input_mask"] = create_int_feature(feature.input_mask)
        features["segment_ids"] = create_int_feature(feature.segment_ids)
        features["loc"] = create_int_feature(feature.loc)
        features["mas"] = create_int_feature(feature.mas)
        features["e1_mas"] = create_int_feature(feature.e1_mas)
        features["e2_mas"] = create_int_feature(feature.e2_mas)
        features["cls_mask"] = create_int_feature(feature.cls_mask)
        features["label_ids"] = create_int_feature(feature.label_id)
    
        tf_example = tf.train.Example(features=tf.train.Features(feature=features))
        writer.write(tf_example.SerializeToString())
      writer.close()
    

    Nothing much to add here: the features are serialized into a TFRecord file for training.
    That concludes the data-processing part of mre-in-one-pass.

    Reference code: https://sourcegraph.com/github.com/helloeve/mre-in-one-pass/-/blob/run_classifier.py#L550

  • Original post: https://www.cnblogs.com/xiximayou/p/14554817.html