Writing the TFRecord
print("convert data into tfrecord:train ") out_file_train = "/home/huadong.wang/bo.yan/fudan_mtl/data/ace2005/bn_nw.train.tfrecord" writer = tf.python_io.TFRecordWriter(out_file_train) for i in tqdm(range(len(data_train))): record = tf.train.Example(features=tf.train.Features(feature={ 'word_ids': tf.train.Feature(bytes_list=tf.train.BytesList(value=[train_x[i].tostring()])), 'et_ids1': tf.train.Feature(bytes_list=tf.train.BytesList(value=[train_et1[i].tostring()])), 'et_ids2': tf.train.Feature(bytes_list=tf.train.BytesList(value=[train_et2[i].tostring()])), 'position_ids1': tf.train.Feature(bytes_list=tf.train.BytesList(value=[train_p1[i].tostring()])), 'position_ids2': tf.train.Feature(bytes_list=tf.train.BytesList(value=[train_p1[i].tostring()])), 'chunks': tf.train.Feature(bytes_list=tf.train.BytesList(value=[train_chunks[i].tostring()])), 'spath_ids': tf.train.Feature(bytes_list=tf.train.BytesList(value=[train_spath[i].tostring()])), 'seq_len': tf.train.Feature(int64_list=tf.train.Int64List(value=[train_x_len[i]])), 'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[np.argmax(train_relation[i])])), 'task': tf.train.Feature(int64_list=tf.train.Int64List(value=[np.int64(0)])) })) writer.write(record.SerializeToString()) writer.close()
Parsing the TFRecord
def _parse_tfexample(serialized_example):
    '''Parse a serialized tf.train.SequenceExample into tensors.

    context features : label, task, seq_len
    sequence features: word_ids, et_ids1, et_ids2, position_ids1,
                       position_ids2, chunks, spath_ids
    '''
    context_features = {
        'label':   tf.FixedLenFeature([], tf.int64),
        'task':    tf.FixedLenFeature([], tf.int64),
        'seq_len': tf.FixedLenFeature([], tf.int64),
    }
    sequence_features = {
        'word_ids':      tf.FixedLenSequenceFeature([], tf.int64),
        'et_ids1':       tf.FixedLenSequenceFeature([], tf.int64),
        'et_ids2':       tf.FixedLenSequenceFeature([], tf.int64),
        'position_ids1': tf.FixedLenSequenceFeature([], tf.int64),
        'position_ids2': tf.FixedLenSequenceFeature([], tf.int64),
        'chunks':        tf.FixedLenSequenceFeature([], tf.int64),
        'spath_ids':     tf.FixedLenSequenceFeature([], tf.int64),
    }
    context_dict, sequence_dict = tf.parse_single_sequence_example(
        serialized_example,
        context_features=context_features,
        sequence_features=sequence_features)

    sentence = (sequence_dict['word_ids'], sequence_dict['et_ids1'],
                sequence_dict['et_ids2'], sequence_dict['position_ids1'],
                sequence_dict['position_ids2'], sequence_dict['chunks'],
                sequence_dict['spath_ids'], context_dict['seq_len'])
    label = context_dict['label']
    task = context_dict['task']
    return task, label, sentence


def read_tfrecord(epoch, batch_size):
    # Yield one (train, test) dataset pair per corpus in DATASETS.
    for dataset in DATASETS:
        train_record_file = os.path.join(OUT_DIR, dataset + '.train.tfrecord')
        test_record_file = os.path.join(OUT_DIR, dataset + '.test.tfrecord')
        train_data = util.read_tfrecord(train_record_file, epoch, batch_size,
                                        _parse_tfexample, shuffle=True)
        test_data = util.read_tfrecord(test_record_file, epoch, batch_size,
                                       _parse_tfexample, shuffle=False)
        yield train_data, test_data
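util.read_tfrecord is the project's own input helper and its body is not shown here. Under the TF 1.x tf.data API, a plausible implementation could look like the following sketch (the shuffle buffer size and the dynamic-padding shapes are assumptions):

def read_tfrecord(filename, epoch, batch_size, parse_fn, shuffle=True):
    # Parse each serialized SequenceExample, shuffle/repeat, then batch with
    # dynamic padding so variable-length sentences can share a batch.
    dataset = tf.data.TFRecordDataset([filename])
    dataset = dataset.map(parse_fn)
    if shuffle:
        dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.repeat(epoch)
    # Structure matches (task, label, sentence): task, label, and seq_len are
    # scalars; the seven id sequences are padded to the longest in the batch.
    padded_shapes = ([], [], ([None], [None], [None], [None],
                              [None], [None], [None], []))
    dataset = dataset.padded_batch(batch_size, padded_shapes)
    return dataset.make_one_shot_iterator().get_next()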
Using it in the model:
def build_task_graph(self, data):
    task_label, labels, sentence = data
    # sentence is the tuple assembled in _parse_tfexample; unpack it instead of
    # the old single-tensor lookup:
    # sentence = tf.nn.embedding_lookup(self.word_embed, sentence)
    word_ids, et_ids1, et_ids2, position_ids1, position_ids2, \
        chunks, spath_ids, seq_len = sentence

    self.word_ids = word_ids
    self.position_ids1 = position_ids1
    self.position_ids2 = position_ids2
    self.et_ids1 = et_ids1
    self.et_ids2 = et_ids2
    self.chunks_ids = chunks
    self.spath_ids = spath_ids
    self.seq_len = seq_len

    sentence = self.add_embedding_layers()
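add_embedding_layers is not shown in this section. The sketch below only illustrates its intended shape: it assumes self.word_embed exists (it appears in the commented-out lookup above), while self.pos_embed and self.et_embed are hypothetical embedding tables for the position and entity-type ids:

def add_embedding_layers(self):
    # Look up an embedding for each id feature stored on self and concatenate
    # them along the last axis into one per-token representation.
    word_emb = tf.nn.embedding_lookup(self.word_embed, self.word_ids)
    pos1_emb = tf.nn.embedding_lookup(self.pos_embed, self.position_ids1)  # hypothetical table
    pos2_emb = tf.nn.embedding_lookup(self.pos_embed, self.position_ids2)
    et1_emb  = tf.nn.embedding_lookup(self.et_embed, self.et_ids1)         # hypothetical table
    et2_emb  = tf.nn.embedding_lookup(self.et_embed, self.et_ids2)
    # chunks / spath ids could be embedded the same way if the model uses them.
    return tf.concat([word_emb, pos1_emb, pos2_emb, et1_emb, et2_emb], axis=-1)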