zoukankan      html  css  js  c++  java
  • 关于 TFRecord

    写入 TFRecord

            # Serialize every training sample as one tf.train.Example and append
            # it to the training TFRecord file.
            print("convert data into tfrecord:train\n")
            out_file_train = "/home/huadong.wang/bo.yan/fudan_mtl/data/ace2005/bn_nw.train.tfrecord"

            def _bytes_feature(array):
                # Store the array's raw buffer. tobytes() replaces the deprecated
                # ndarray.tostring() alias and produces the same bytes.
                return tf.train.Feature(bytes_list=tf.train.BytesList(value=[array.tobytes()]))

            def _int64_feature(value):
                # Wrap a single integer in an Int64List feature.
                return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

            # The context manager closes the writer even when a sample fails to
            # serialize, so the output file handle is never leaked.
            with tf.python_io.TFRecordWriter(out_file_train) as writer:
                for i in tqdm(range(len(data_train))):
                    record = tf.train.Example(features=tf.train.Features(feature={
                        'word_ids': _bytes_feature(train_x[i]),
                        'et_ids1': _bytes_feature(train_et1[i]),
                        'et_ids2': _bytes_feature(train_et2[i]),
                        'position_ids1': _bytes_feature(train_p1[i]),
                        # NOTE(review): this entry previously serialized train_p1 a
                        # second time, which duplicated position_ids1. It assumes a
                        # train_p2 array exists alongside train_p1 (mirroring
                        # train_et1/train_et2) — confirm against the data loader.
                        'position_ids2': _bytes_feature(train_p2[i]),
                        'chunks': _bytes_feature(train_chunks[i]),
                        'spath_ids': _bytes_feature(train_spath[i]),
                        'seq_len': _int64_feature(train_x_len[i]),
                        # The label is stored as the index of the largest entry of
                        # the relation vector.
                        'label': _int64_feature(np.argmax(train_relation[i])),
                        'task': _int64_feature(np.int64(0)),
                    }))
                    writer.write(record.SerializeToString())
    

      

    解析 TFRecord

    def _parse_tfexample(serialized_example):
      """Decode one serialized tf.train.SequenceExample into tensors.

      Context features (scalar int64): ``label``, ``task``, ``seq_len``.
      Sequence features (int64): ``word_ids``, ``et_ids1``, ``et_ids2``,
      ``position_ids1``, ``position_ids2``, ``chunks``, ``spath_ids``.

      Returns:
        A ``(task, label, sentence)`` tuple, where ``sentence`` holds the seven
        sequence tensors in the order listed above followed by ``seq_len``.

      NOTE(review): the writer snippet earlier in this file emits a
      tf.train.Example whose id arrays are bytes features, which does not match
      the SequenceExample/int64 layout requested here — confirm which record
      format is actually in use.
      """
      context_names = ('label', 'task', 'seq_len')
      sequence_names = ('word_ids', 'et_ids1', 'et_ids2', 'position_ids1',
                        'position_ids2', 'chunks', 'spath_ids')

      context_spec = {name: tf.FixedLenFeature([], tf.int64)
                      for name in context_names}
      sequence_spec = {name: tf.FixedLenSequenceFeature([], tf.int64)
                       for name in sequence_names}

      context, sequences = tf.parse_single_sequence_example(
          serialized_example,
          context_features=context_spec,
          sequence_features=sequence_spec)

      # The sequence length travels with the sequence tensors as the last item.
      sentence = tuple(sequences[name] for name in sequence_names)
      sentence += (context['seq_len'],)

      return context['task'], context['label'], sentence
    
    
    
    def read_tfrecord(epoch, batch_size):
      """Yield a ``(train_data, test_data)`` pair for each entry of ``DATASETS``.

      For every dataset name, the ``<name>.train.tfrecord`` and
      ``<name>.test.tfrecord`` files under ``OUT_DIR`` are handed to
      ``util.read_tfrecord`` together with ``epoch``, ``batch_size`` and
      ``_parse_tfexample``. Only the training split is read with
      ``shuffle=True``.
      """
      # Train is loaded before test, and only train is shuffled.
      split_plan = (('train', True), ('test', False))
      for name in DATASETS:
        loaded = []
        for split, do_shuffle in split_plan:
          record_path = os.path.join(OUT_DIR, name + '.' + split + '.tfrecord')
          loaded.append(util.read_tfrecord(record_path,
                                           epoch,
                                           batch_size,
                                           _parse_tfexample,
                                           shuffle=do_shuffle))
        yield tuple(loaded)
    

    模型中使用:

      def build_task_graph(self, data):
        """Unpack one parsed input and store its feature tensors on ``self``.

        ``data`` is a ``(task_label, labels, sentence)`` triple, where
        ``sentence`` is the 8-tuple ``(word_ids, et_ids1, et_ids2,
        position_ids1, position_ids2, chunks, spath_ids, seq_len)``. Each item
        is saved as an attribute of the same name (``chunks`` becomes
        ``self.chunks_ids``) before ``self.add_embedding_layers()`` is called.
        """
        _task_label, _labels, features = data

        # Unpack straight into the attributes; the tuple order fixes the mapping.
        (self.word_ids,
         self.et_ids1,
         self.et_ids2,
         self.position_ids1,
         self.position_ids2,
         self.chunks_ids,
         self.spath_ids,
         self.seq_len) = features

        # add_embedding_layers takes no arguments here; presumably it reads the
        # attributes set above — verify against its definition.
        sentence = self.add_embedding_layers()
    

      

     

  • 相关阅读:
    性能监控(5)–JAVA下的jstat命令
    内存分析工具-MAT(Memory Analyzer Tool)
    性能监控(4)–linux下的pidstat命令
    性能监控(3)–linux下的iostat命令
    性能监控(2)–linux下的vmstat命令
    性能监控(1)--linux下的top命令
    了解java虚拟机—在TALB上分配对象(10)
    了解java虚拟机—G1回收器(9)
    js 长按鼠标左键实现溢出内容左右滚动滚动
    html标签设置contenteditable时,去除粘贴文本自带样式
  • 原文地址:https://www.cnblogs.com/huadongw/p/11483730.html
Copyright © 2011-2022 走看看