  • [Text Summarization Project] 5: Performance Improvement with the PGN Model

    Background

        The previous articles in this series walked through the basic end-to-end text summarization pipeline: text preprocessing, attention-based seq2seq summary generation, decoding algorithms, and evaluation of the generated results. With those steps in place we already have a complete summarization process, and this article aims to improve it further. The model implemented here is the Pointer-Generator Network (PGN), a fairly classic model from the era before Transformer and BERT. Its theory has been covered in other articles, so this one focuses on the code.

    Core Content

        This article covers the main parts that are modified on top of the original baseline: data loading, model construction, and model training. The model evaluation part is essentially the same as in the seq2seq-based pipeline, so its implementation is not described again. The complete code is attached at the end of the article.

    Overall Pipeline

        The overall pipeline was described in detail in the earlier article on seq2seq + Attention summarization. This article keeps that architecture essentially unchanged and focuses on improving model performance with a new method, so the architecture is not repeated here. Only some details of the pipeline are optimized, for example: loading data with a generator, improving the learning rate schedule and the loss function, and taking the padding mask into account when computing attention. The specifics are described alongside the code below.

    Generator-based Batch Data Loading

    # build the datasets
    train_dataset, params['train_steps_per_epoch'] = batcher(vocab, params)
    valid_dataset, params['valid_steps_per_epoch'] = batcher(vocab, params)
    logger.info(f'Building the dataset for train/valid ...')
    

        Here, batcher is what differs in this data-loading approach.

    def batcher(vocab, params):
        dataset = batch_generator(example_generator,
                                  params,
                                  vocab,
                                  params['max_enc_len'],
                                  params['max_dec_len'],
                                  params['batch_size'],
                                  params['mode'],
                                  )
    
        dataset = dataset.prefetch(params['buffer_size'])
        steps_per_epoch = get_steps_per_epoch(params)
    
        return dataset, steps_per_epoch
    

        example_generator is the concrete generator used for loading data; it can be written to match whatever format the data is in. batch_generator then takes what example_generator yields and builds it into the same tf.data Dataset format that TensorFlow consumed before, so the key point is keeping example_generator compatible with TensorFlow's usual Dataset tooling. The overall flow is the same as before. Specifically:

    def batch_generator(generator, params, vocab, max_enc_len, max_dec_len, batch_size, mode):
        dataset = tf.data.Dataset.from_generator(lambda: generator(params,
                                                                   vocab,
                                                                   max_enc_len,
                                                                   max_dec_len,
                                                                   mode,
                                                                   # batch_size
                                                                   ),
                                                 output_types={
                                                     'enc_len': tf.int32,
                                                     'enc_input': tf.int32,
                                                     'enc_input_extend_vocab': tf.int32,
                                                     'article_oovs': tf.string,
                                                     'dec_input': tf.int32,
                                                     'target': tf.int32,
                                                     'dec_len': tf.int32,
                                                     'article': tf.string,
                                                     'abstract': tf.string,
                                                     'abstract_sents': tf.string,
                                                     'decoder_pad_mask': tf.int32,
                                                     'encoder_pad_mask': tf.int32},
                                                 output_shapes={
                                                     'enc_len': [],
                                                     'enc_input': [None],
                                                     'enc_input_extend_vocab': [None],
                                                     'article_oovs': [None],
                                                     'dec_input': [None],
                                                     'target': [None],
                                                     'dec_len': [],
                                                     'article': [],
                                                     'abstract': [],
                                                     'abstract_sents': [],
                                                     'decoder_pad_mask': [None],
                                                     'encoder_pad_mask': [None]})
    
        dataset = dataset.padded_batch(batch_size=batch_size,
                                       padded_shapes=({'enc_len': [],
                                                       'enc_input': [None],
                                                       'enc_input_extend_vocab': [None],
                                                       'article_oovs': [None],
                                                       'dec_input': [max_dec_len],
                                                       'target': [max_dec_len],
                                                       'dec_len': [],
                                                       'article': [],
                                                       'abstract': [],
                                                       'abstract_sents': [],
                                                       'decoder_pad_mask': [max_dec_len],
                                                       'encoder_pad_mask': [None]}),
                                       padding_values={'enc_len': -1,
                                                       'enc_input': vocab.word2index[vocab.PAD_TOKEN],
                                                       'enc_input_extend_vocab': vocab.word2index[vocab.PAD_TOKEN],
                                                       'article_oovs': b'',
                                                       'dec_input': vocab.word2index[vocab.PAD_TOKEN],
                                                       'target': vocab.word2index[vocab.PAD_TOKEN],
                                                       'dec_len': -1,
                                                       'article': b'',
                                                       'abstract': b'',
                                                       'abstract_sents': b'',
                                                       'decoder_pad_mask': 0,
                                                       'encoder_pad_mask': 0},
                                       drop_remainder=True)
    
        def update(entry):
            return ({
                        "enc_input": entry["enc_input"],
                        "extended_enc_input": entry["enc_input_extend_vocab"],
                        "article_oovs": entry["article_oovs"],
                        "enc_len": entry["enc_len"],
                        "article": entry["article"],
                        "max_oov_len": tf.shape(entry["article_oovs"])[1],
                        "encoder_pad_mask": entry["encoder_pad_mask"]
                    },
                    {
                        "dec_input": entry["dec_input"],
                        "dec_target": entry["target"],
                        "dec_len": entry["dec_len"],
                        "abstract": entry["abstract"],
                        "decoder_pad_mask": entry["decoder_pad_mask"]
                    })
    
        dataset = dataset.map(update)
    
        return dataset
    

        In batch_generator, the tf.data.Dataset.from_generator() API lets the data be produced one example at a time by the generator; the output_types and output_shapes arguments must be specified for it. Some preprocessing can also happen while the data is generated: padded_batch aligns the sequence lengths within a batch and automatically pads short sequences with the specified padding values.
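
        As a quick sanity check, the batched dataset can be iterated directly. The snippet below is only a hypothetical usage sketch; the shapes in the comments follow from the padded_batch settings above.

    # Hypothetical usage sketch: peek at one padded batch produced by batcher().
    train_dataset, steps_per_epoch = batcher(vocab, params)
    for enc_data, dec_data in train_dataset.take(1):
        print(enc_data['enc_input'].shape)   # (batch_size, longest enc_input in this batch)
        print(enc_data['max_oov_len'])       # scalar: max number of in-article OOV words in this batch
        print(dec_data['dec_target'].shape)  # (batch_size, max_dec_len)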

    def example_generator(params, vocab, max_enc_len, max_dec_len, mode):
    
        if mode != 'test':
    
            dataset_x = tf.data.TextLineDataset(params[f'{mode}_seg_x_dir'])
            dataset_y = tf.data.TextLineDataset(params[f'{mode}_seg_y_dir'])
    
            train_dataset = tf.data.Dataset.zip((dataset_x, dataset_y)).take(count=10000)
    
            if mode == 'train':
                train_dataset = train_dataset.shuffle(10, reshuffle_each_iteration=True).repeat(1)
    
            for raw_record in train_dataset:
    
                start_decoding = vocab.word_to_index(vocab.START_DECODING)
                stop_decoding = vocab.word_to_index(vocab.STOP_DECODING)
    
                article = raw_record[0].numpy().decode('utf-8')
                article_words = article.split()[:max_enc_len]
    
                enc_input = [vocab.word_to_index(w) for w in article_words]
                enc_input_extend_vocab, article_oovs = article_to_index(article_words, vocab)
    
                # add start and stop flag
                enc_input = get_enc_inp_targ_seqs(enc_input,
                                                  max_enc_len,
                                                  start_decoding,
                                                  stop_decoding)
    
                enc_input_extend_vocab = get_enc_inp_targ_seqs(enc_input_extend_vocab,
                                                               max_enc_len,
                                                               start_decoding,
                                                               stop_decoding)
    
                # encoder input length (used to build the mask)
                enc_len = len(enc_input)
                # build the encoder padding mask
                encoder_pad_mask = [1 for _ in range(enc_len)]
    
                abstract = raw_record[1].numpy().decode('utf-8')
                abstract_words = abstract.split()
                abs_ids = [vocab.word_to_index(w) for w in abstract_words]
    
                dec_input, target = get_dec_inp_targ_seqs(abs_ids,
                                                          max_dec_len,
                                                          start_decoding,
                                                          stop_decoding)
    
                if params['pointer_gen']:
                    abs_ids_extend_vocab = abstract_to_index(abstract_words, vocab, article_oovs)
                    _, target = get_dec_inp_targ_seqs(abs_ids_extend_vocab,
                                                      max_dec_len,
                                                      start_decoding,
                                                      stop_decoding)
                # decoder target length (used to build the mask)
                dec_len = len(target)
                # build the decoder padding mask
                decoder_pad_mask = [1 for _ in range(dec_len)]
    
                output = {
                    "enc_len": enc_len,
                    "enc_input": enc_input,
                    "enc_input_extend_vocab": enc_input_extend_vocab,
                    "article_oovs": article_oovs,
                    "dec_input": dec_input,
                    "target": target,
                    "dec_len": dec_len,
                    "article": article,
                    "abstract": abstract,
                    "abstract_sents": abstract,
                    "decoder_pad_mask": decoder_pad_mask,
                    "encoder_pad_mask": encoder_pad_mask
                }
    
                yield output
        else:
            test_dataset = tf.data.TextLineDataset(params['valid_seg_x_dir'])
            for raw_record in test_dataset:
                article = raw_record.numpy().decode('utf-8')
                article_words = article.split()[: max_enc_len]
                enc_len = len(article_words)
    
                enc_input = [vocab.word_to_index(w) for w in article_words]
                enc_input_extend_vocab, article_oovs = article_to_index(article_words, vocab)
    
                # build the encoder padding mask
                encoder_pad_mask = [1 for _ in range(enc_len)]
    
                output = {
                    "enc_len": enc_len,
                    "enc_input": enc_input,
                    "enc_input_extend_vocab": enc_input_extend_vocab,
                    "article_oovs": article_oovs,
                    "dec_input": [],
                    "target": [],
                    "dec_len": params['max_dec_len'],
                    "article": article,
                    "abstract": '',
                    "abstract_sents": '',
                    "decoder_pad_mask": [],
                    "encoder_pad_mask": encoder_pad_mask
                }
                # every example in the batch is identical -- the input is duplicated for beam search
                if params["decode_mode"] == "beam":
                    for _ in range(params["batch_size"]):
                        yield output
                elif params["decode_mode"] == "greedy":
                    yield output
                else:
                    print("shit")
    

        The example_generator function is where the concrete data processing happens, including adding the start and stop tokens (START_DECODING / STOP_DECODING). A very important part of the PGN model is its copy ability, which gives the model some capacity to deal with the OOV problem. Where does that show up? The words copied by the pointer network come from the input text, so to some extent the model can output words that appear in the input text but are not in the vocabulary. In essence, it partially recovers the low-frequency words that were filtered out when the vocabulary was built.

        Another important point of the PGN model is that its final predicted distribution combines the probability distribution over the vocabulary with the attention distribution over the input text. (My own understanding: words that appear both in the vocabulary and in the input text have their probability boosted; with greedy decoding that keeps the top 1 or top K words, words outside the original vocabulary still tend to get relatively small probabilities, so the model only partially alleviates the OOV problem.)
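
        Written out, this is the standard pointer-generator mixture, which _calc_final_dist() below implements:

    P_final(w) = p_gen * P_vocab(w) + (1 - p_gen) * sum over {i : w_i = w} of a_i

        where p_gen comes from the Pointer layer, P_vocab is the decoder softmax over the fixed vocabulary (zero-padded out to the extended vocabulary), and a is the masked attention distribution over the source positions; the attention mass on source word w_i is scattered onto that word's entry in the extended vocabulary.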

        In the code, the basis for both of the points above lies in the data processing: the functions article_to_index and abstract_to_index re-label words that would otherwise be marked UNK with indices for the words as they appear in the input text (i.e. fewer UNKs).

    def article_to_index(article_words, vocab):
    
        oov_words = []
        extend_vocab_index = []
    
        unk_index = vocab.UNKNOWN_TOKEN_INDEX
    
        for word in article_words:
            word_index = vocab.word_to_index(word)
            if word_index == unk_index:
                if word not in oov_words:
                    oov_words.append(word)
    
                oov_num = oov_words.index(word)
                extend_vocab_index.append(vocab.size() + oov_num)
            else:
                extend_vocab_index.append(word_index)
    
        return extend_vocab_index, oov_words
    

        The code above shows how the original vocabulary is extended; only the handling of the training inputs X is shown here, and the handling of y is similar. See the complete code later.
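
        As a small illustration of article_to_index, the sketch below runs it on a toy input; ToyVocab is a hypothetical stand-in that only mimics the word_to_index / size / UNKNOWN_TOKEN_INDEX interface used above.

    # Hypothetical toy vocab, just to show the extended-vocabulary indices.
    class ToyVocab:
        UNKNOWN_TOKEN_INDEX = 0

        def __init__(self, words):
            self._w2i = {w: i + 1 for i, w in enumerate(words)}

        def word_to_index(self, word):
            return self._w2i.get(word, self.UNKNOWN_TOKEN_INDEX)

        def size(self):
            return len(self._w2i) + 1

    toy_vocab = ToyVocab(['car', 'engine'])  # size 3, including UNK
    ids, oovs = article_to_index(['car', 'leak', 'engine', 'leak'], toy_vocab)
    print(ids)   # [1, 3, 2, 3] -> the OOV word 'leak' gets index vocab.size() + 0 = 3
    print(oovs)  # ['leak']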

    Code Optimization for Model Saving

    # build the checkpoint manager
    checkpoint = tf.train.Checkpoint(PGN=model)
    checkpoint_manager = tf.train.CheckpointManager(checkpoint, params['checkpoint_dir'], max_to_keep=5)
    
    if checkpoint_manager.latest_checkpoint:
        checkpoint_manager.restore(checkpoint_manager.latest_checkpoint)
        params['trained_epoch'] = int(checkpoint_manager.latest_checkpoint.split('-')[-1])  # checkpoint names look like 'ckpt-<n>'
        logger.info(f'Building model by restoring {checkpoint_manager.latest_checkpoint}')
    else:
        params['trained_epoch'] = 1
        logger.info('Building model from initial ...')
    
    # set the learning rate (decayed by how many epochs were already trained)
    params['learning_rate'] *= np.power(0.95, params['trained_epoch'])
    logger.info(f'Learning rate : {params["learning_rate"]}')
    

        The improvements in the code above are: a dynamically decayed learning rate, and automatically restoring the best model saved by the previous run together with how many epochs have already been trained.
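
        For example, assuming a hypothetical base learning_rate of 0.15, the per-epoch 0.95 decay works out as follows:

    # Illustrative only: how the restored-epoch decay scales a hypothetical base rate.
    import numpy as np

    base_lr = 0.15
    for trained_epoch in [1, 3, 5]:
        print(trained_epoch, base_lr * np.power(0.95, trained_epoch))
    # 1 -> 0.1425, 3 -> ~0.1286, 5 -> ~0.1161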

    Building the PGN Model

    class PGN(keras.Model):
    
        def __init__(self, params):
            super(PGN, self).__init__()
            self.embedding_matrix = load_embedding_matrix(max_vocab_size=params['vocab_size'])
    
            self.vocab_size = params['vocab_size']
            self.batch_size = params['batch_size']
    
            self.encoder = Encoder(self.embedding_matrix,
                                   params['enc_units'],
                                   params['batch_size'])
    
            self.decoder = Decoder(self.embedding_matrix,
                                   params['dec_units'],
                                   params['batch_size'])
    
            self.pointer = Pointer()
    
        def call_one_step(self, dec_input, dec_hidden, enc_output, enc_pad_mask, use_coverage, prev_coverage):
            context_vector, dec_hidden, dec_x, prediction, attention_weights, coverage = self.decoder(dec_input,
                                                                                                      dec_hidden,
                                                                                                      enc_output,
                                                                                                      enc_pad_mask,
                                                                                                      prev_coverage,
                                                                                                      use_coverage)
    
            p_gens = self.pointer(context_vector, dec_hidden, dec_x)
    
            return prediction, dec_hidden, context_vector, attention_weights, p_gens, coverage
    
        def call(self, dec_input, dec_hidden, enc_output, enc_extended_input, batch_oov_len, enc_pad_mask, use_coverage,
                 coverage=None):
            predictions = []
            attentions = []
            p_gens = []
            coverages = []
    
            for t in range(dec_input.shape[1]):
                final_dists, dec_hidden, context_vector, attention_weights, p_gen, coverage = self.call_one_step(
                    dec_input[:, t],
                    dec_hidden,
                    enc_output,
                    enc_pad_mask,
                    use_coverage,
                    coverage)
    
                coverages.append(coverage)
                predictions.append(final_dists)
                attentions.append(attention_weights)
                p_gens.append(p_gen)
    
            final_dists = _calc_final_dist(enc_extended_input,
                                           predictions,
                                           attentions,
                                           p_gens,
                                           batch_oov_len,
                                           self.vocab_size,
                                           self.batch_size)
    
            attentions = tf.stack(attentions, axis=1)
    
            return tf.stack(final_dists, 1), attentions, tf.stack(coverages, 1)
    
    
    def _calc_final_dist(_enc_batch_extend_vocab, vocab_dists, attn_dists, p_gens, batch_oov_len, vocab_size, batch_size):
        """
        Calculate the final distribution, for the pointer-generator model
        Args:
        vocab_dists: The vocabulary distributions. List length max_dec_steps of (batch_size, vsize) arrays.
                    The words are in the order they appear in the vocabulary file.
        attn_dists: The attention distributions. List length max_dec_steps of (batch_size, attn_len) arrays
        Returns:
        final_dists: The final distributions. List length max_dec_steps of (batch_size, extended_vsize) arrays.
        """
        # Multiply vocab dists by p_gen and attention dists by (1-p_gen)
        vocab_dists = [p_gen * dist for (p_gen, dist) in zip(p_gens, vocab_dists)]
        attn_dists = [(1 - p_gen) * dist for (p_gen, dist) in zip(p_gens, attn_dists)]
    
        # Concatenate some zeros to each vocabulary dist, to hold the probabilities for in-article OOV words
        extended_vsize = vocab_size + batch_oov_len  # the maximum (over the batch) size of the extended vocabulary
        extra_zeros = tf.zeros((batch_size, batch_oov_len))
        # list length max_dec_steps of shape (batch_size, extended_vsize)
        vocab_dists_extended = [tf.concat(axis=1, values=[dist, extra_zeros]) for dist in vocab_dists]
    
        # Project the values in the attention distributions onto the appropriate entries in the final distributions
        # This means that if a_i = 0.1 and the ith encoder word is w, and w has index 500 in the vocabulary,
        # then we add 0.1 onto the 500th entry of the final distribution
        # This is done for each decoder timestep.
        # This is fiddly; we use tf.scatter_nd to do the projection
        batch_nums = tf.range(0, limit=batch_size)  # shape (batch_size)
        batch_nums = tf.expand_dims(batch_nums, 1)  # shape (batch_size, 1)
        attn_len = tf.shape(_enc_batch_extend_vocab)[1]  # number of states we attend over
        batch_nums = tf.tile(batch_nums, [1, attn_len])  # shape (batch_size, attn_len)
        indices = tf.stack((batch_nums, _enc_batch_extend_vocab), axis=2)  # shape (batch_size, enc_t, 2)
        shape = [batch_size, extended_vsize]
        # list length max_dec_steps (batch_size, extended_vsize)
        attn_dists_projected = [tf.scatter_nd(indices, copy_dist, shape) for copy_dist in attn_dists]
    
        # Add the vocab distributions and the copy distributions together to get the final distributions
        # final_dists is a list length max_dec_steps; each entry is a tensor shape (batch_size, extended_vsize) giving
        # the final distribution for that decoder timestep
        # Note that for decoder timesteps and examples corresponding to a [PAD] token, this is junk - ignore.
        final_dists = [vocab_dist + copy_dist for (vocab_dist, copy_dist) in
                       zip(vocab_dists_extended, attn_dists_projected)]
    
        return final_dists
    

        The PGN model consists mainly of the following parts: encoder, decoder, and gen_pointer. The implementation here uses Teacher Forcing during word generation. The function _calc_final_dist() computes the final, extended probability distribution. call_one_step() is invoked by call() and generates one decoded word at a time. The model's components are introduced below.

    class Encoder(keras.Model):
    
        def __init__(self, embedding_matrix, enc_units, batch_size):
            super(Encoder, self).__init__()
    
            self.batch_size = batch_size
            self.enc_units = enc_units
            self.vocab_size, self.embedding_dim = embedding_matrix.shape
    
            self.embedding = keras.layers.Embedding(self.vocab_size,
                                                    self.embedding_dim,
                                                    weights=[embedding_matrix],
                                                    trainable=False)
    
            self.gru = keras.layers.GRU(self.enc_units,
                                        return_state=True,
                                        return_sequences=True,
                                        recurrent_initializer='glorot_uniform')
    
            self.bidirectional_gru = keras.layers.Bidirectional(self.gru)
    
        def call(self, x, enc_hidden):
    
            x = self.embedding(x)  # (batch_size, max_len) -> (batch_size, max_len, embedding_dim)
    
            # enc_output shape: (batch_size, max_len, 2 * enc_units)
            enc_output, forward_state, backward_state = self.bidirectional_gru(x, initial_state=[enc_hidden, enc_hidden])
    
            # enc_hidden shape: (batch_size, 2 * enc_units)
            enc_hidden = keras.layers.concatenate([forward_state, backward_state], axis=-1)
    
            return enc_output, enc_hidden
    
        def initialize_hidden_state(self):
    
            return tf.zeros(shape=(self.batch_size, self.enc_units))
    

        The encoder is basically the same as in the earlier seq2seq model; the difference is that a bidirectional GRU is used here for the encoded representation.

    def masked_attention(enc_pad_mask, attn_dist):
    
        attn_dist = tf.squeeze(attn_dist, axis=2)
        mask = tf.cast(enc_pad_mask, dtype=attn_dist.dtype)
    
        attn_dist *= mask
    
        mask_sum = tf.reduce_sum(attn_dist, axis=1)
        attn_dist = attn_dist / tf.reshape(mask_sum + 1e-12, [-1, 1])
    
        attn_dist = tf.expand_dims(attn_dist, axis=2)
    
        return attn_dist
    
    
    class BahdanauAttention(keras.layers.Layer):
        def __init__(self, units):
            super(BahdanauAttention, self).__init__()
    
            self.W_s = keras.layers.Dense(units)
            self.W_h = keras.layers.Dense(units)
            self.W_c = keras.layers.Dense(units)
            self.V = keras.layers.Dense(1)
    
        def call(self, dec_hidden, enc_output, enc_pad_mask, use_coverage=False, pre_coverage=None):
    
            hidden_with_time_axis = tf.expand_dims(dec_hidden, 1)
    
            if use_coverage and pre_coverage is not None:
                score = self.V(tf.nn.tanh(self.W_s(enc_output) + self.W_h(hidden_with_time_axis) + self.W_c(pre_coverage)))
    
                attention_weights = tf.nn.softmax(score, axis=1)
                attention_weights = masked_attention(enc_pad_mask, attention_weights)
                coverage = attention_weights + pre_coverage
            else:
                score = self.V(tf.nn.tanh(self.W_s(enc_output) + self.W_h(hidden_with_time_axis)))
    
                attention_weights = tf.nn.softmax(score)
                attention_weights = masked_attention(enc_pad_mask, attention_weights)
    
                if use_coverage:
                    coverage = attention_weights
                else:
                    coverage = []
    
            context_vactor = attention_weights * enc_output
            context_vactor = tf.reduce_sum(context_vactor, axis=1)
    
            return context_vactor, tf.squeeze(attention_weights, -1), coverage
    

        Attention is generally computed in three steps: 1. compute the attention scores; 2. softmax; 3. weighted sum (reduce_sum). The attention here differs slightly from the previous implementation: 1. the padding mask is taken into account around the softmax, so in masked_attention the masked positions do not take part in the computation; 2. the attention from the previous step is taken into account, with the use_coverage parameter controlling whether the coverage mechanism is used.
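
        A minimal sketch of what masked_attention does to a single example (toy numbers, assuming the masked_attention function above is in scope):

    import tensorflow as tf

    attn = tf.constant([[[0.5], [0.3], [0.2]]])  # shape (batch=1, enc_len=3, 1)
    mask = tf.constant([[1, 1, 0]])              # the last source position is padding
    print(masked_attention(mask, attn))
    # -> [[[0.625], [0.375], [0.0]]]: the masked position drops to 0 and the rest is renormalized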

    class Decoder(keras.Model):
    
        def __init__(self, embedding_matrix, dec_units, batch_size):
            super(Decoder, self).__init__()
            self.batch_size = batch_size
            self.dec_units = dec_units
            self.vocab_size, self.embedding_dim = embedding_matrix.shape
    
            self.embedding = keras.layers.Embedding(self.vocab_size,
                                                    self.embedding_dim,
                                                    weights=[embedding_matrix],
                                                    trainable=False)
            self.cell = keras.layers.GRUCell(units=self.dec_units, recurrent_initializer='glorot_uniform')
    
            self.fc = keras.layers.Dense(self.vocab_size, activation=keras.activations.softmax)
    
            self.attention = BahdanauAttention(self.dec_units)
    
        def call(self, dec_input, dec_hidden, enc_output, enc_pad_mask, pre_coverage, use_coverage=True):
    
            dec_x = self.embedding(dec_input)
    
            dec_output, [dec_hidden] = self.cell(dec_x, [dec_hidden])
    
            context_vector, attention_weights, coverage = self.attention(dec_hidden,
                                                                         enc_output,
                                                                         enc_pad_mask,
                                                                         use_coverage,
                                                                         pre_coverage)
    
            dec_output = tf.concat([dec_output, context_vector], axis=-1)
            prediction = self.fc(dec_output)
    
            return context_vector, dec_hidden, dec_x, prediction, attention_weights, coverage
    

        An important point about the decoder is that decoding proceeds one timestep at a time: the output of the previous timestep and the attention vector feed into the next step's attention computation, so a cell-level GRU (GRUCell) is used here.
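
        In formula form, matching the code above (not necessarily the exact parameterization of the original paper):

    s_t = GRUCell(x_t, s_{t-1})
    P_vocab = softmax(W_fc [s_t ; c_t])

        where x_t is the embedded decoder input, s_t the decoder hidden state, and c_t the context vector returned by the attention layer.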

    class Pointer(keras.layers.Layer):
    
        def __init__(self):
            super(Pointer, self).__init__()
    
            self.w_s_reduce = keras.layers.Dense(1)
            self.w_i_reduce = keras.layers.Dense(1)
            self.w_c_reduce = keras.layers.Dense(1)
    
        def call(self, context_vector, dec_hidden, dec_inp):
    
            return tf.nn.sigmoid(self.w_s_reduce(dec_hidden) +
                                 self.w_c_reduce(context_vector) +
                                 self.w_i_reduce(dec_inp))
    

        According to the PGN theory, the final distribution is computed by weighting two distributions; but with what weight? The Pointer class is mainly responsible for computing that value. This covers the details of the model.
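
        Matching the Pointer layer above, the generation probability is

    p_gen = sigmoid(w_s * s_t + w_c * c_t + w_i * x_t)

        a scalar in (0, 1) computed from the decoder hidden state s_t, the context vector c_t, and the current decoder input embedding x_t; it is the mixture weight used in the final-distribution formula shown earlier.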

    Model Training and Evaluation

    def train_model(model, train_dataset, valid_dataset, params, checkpoint_manager):
        epochs = params['epochs']
    
        optimizer = keras.optimizers.Adagrad(learning_rate=params['learning_rate'],
                                             initial_accumulator_value=params['adagrad_init_acc'],
                                             clipnorm=params['max_grad_norm'],
                                             epsilon=params['eps'])
    
        best_loss = 100
        for epoch in range(epochs):
            start = time.time()
            enc_hidden = model.encoder.initialize_hidden_state()
    
            total_loss = 0.
            total_log_loss = 0.
            total_cov_loss = 0.
            step = 0
            for encoder_batch_data, decoder_batch_data in train_dataset:
    
                batch_loss, log_loss, cov_loss = train_step(model,
                                                            enc_hidden,
                                                            encoder_batch_data['enc_input'],
                                                            encoder_batch_data['extended_enc_input'],
                                                            encoder_batch_data['max_oov_len'],
                                                            decoder_batch_data['dec_input'],
                                                            decoder_batch_data['dec_target'],
                                                            enc_pad_mask=encoder_batch_data['encoder_pad_mask'],
                                                            dec_pad_mask=decoder_batch_data['decoder_pad_mask'],
                                                            params=params,
                                                            optimizer=optimizer,
                                                            mode='train')
    
                step += 1
                total_loss += batch_loss
                total_log_loss += log_loss
                total_cov_loss += cov_loss
                if step % 50 == 0:
                    if params['use_coverage']:
    
                        print('Epoch {} Batch {} avg_loss {:.4f} log_loss {:.4f} cov_loss {:.4f}'.format(epoch + 1,
                                                                                                         step,
                                                                                                         total_loss / step,
                                                                                                         total_log_loss / step,
                                                                                                         total_cov_loss / step))
                    else:
                        print('Epoch {} Batch {} avg_loss {:.4f}'.format(epoch + 1,
                                                                         step,
                                                                         total_loss / step))
    
            # evaluate once per epoch (outside the batch loop)
            valid_total_loss, valid_total_cov_loss, valid_total_log_loss = evaluate(model, valid_dataset, params)
            print('Epoch {} Loss {:.4f}, valid Loss {:.4f}'.format(epoch + 1, total_loss / step, valid_total_loss))
            print('Time taken for 1 epoch {} sec\n'.format(time.time() - start))

            if valid_total_loss < best_loss:
                best_loss = valid_total_loss
                ckpt_save_path = checkpoint_manager.save()
                print('Saving checkpoint for epoch {} at {}, best valid loss {}'.format(epoch + 1,
                                                                                        ckpt_save_path,
                                                                                        best_loss))
    
    
    def train_step(model, enc_hidden, enc_input, extend_enc_input, max_oov_len, dec_input, dec_target, enc_pad_mask, dec_pad_mask, params, optimizer=None, mode='train'):
    
        with tf.GradientTape() as tape:
    
            # run the encoder over the full input sequence
            enc_output, enc_hidden = model.encoder(enc_input, enc_hidden)
    
            # decoder
            dec_hidden = enc_hidden
            final_dists, attentions, coverages = model(dec_input, dec_hidden, enc_output, extend_enc_input, max_oov_len, enc_pad_mask=enc_pad_mask, use_coverage=params['use_coverage'], coverage=None)
    
            batch_loss, log_loss, cov_loss = calc_loss(dec_target, final_dists, dec_pad_mask, attentions, params['cov_loss_wt'], params['eps'])
    
            if mode == 'train':
                variables = (model.encoder.trainable_variables + model.decoder.trainable_variables + model.pointer.trainable_variables)
                gradients = tape.gradient(batch_loss, variables)
                optimizer.apply_gradients(zip(gradients, variables))
    
            return batch_loss, log_loss, cov_loss
    

        The overall training framework stays basically the same; the difference lies in how the loss is computed.

    def calc_loss(real, pred, dec_mask, attentions, cov_loss_wt, eps):
    
        log_loss = pgn_log_loss_function(real, pred, dec_mask, eps)
    
        cov_loss = _coverage_loss(attentions, dec_mask)
    
        return log_loss + cov_loss_wt * cov_loss, log_loss, cov_loss
    

        The loss consists of two parts: the original PGN log loss and the loss introduced by the coverage mechanism, combined as a weighted sum.
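
        Written out, with lambda = cov_loss_wt, the per-timestep loss implemented below is

    loss_t = -log(P_final(w*_t) + eps) + lambda * sum_i min(a_i^t, c_i^t),    where c^t = sum over t' < t of a^{t'}

        Here w*_t is the gold target word at step t and c^t is the coverage vector (the running sum of all previous attention distributions). Both terms are masked with the decoder padding mask and averaged over the true decoder lengths in _mask_and_avg.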

    def pgn_log_loss_function(real, final_dists, padding_mask, eps):
    
        loss_per_step = []
        batch_nums = tf.range(0, limit=real.shape[0])
        final_dists = tf.transpose(final_dists, perm=[1, 0, 2])
        for dec_step, dist in enumerate(final_dists):
            targets = real[:, dec_step]
            indices = tf.stack((batch_nums, targets), axis=1)
            gold_probs = tf.gather_nd(dist, indices)
            losses = -tf.math.log(gold_probs + eps)
            loss_per_step.append(losses)
    
        _loss = _mask_and_avg(loss_per_step, padding_mask)
        return _loss
    
    def _coverage_loss(attn_dists, padding_mask):
    
        attn_dists = tf.transpose(attn_dists, perm=[1, 0, 2])
        coverage = tf.zeros_like(attn_dists[0])
    
        covlosses = []
        for a in attn_dists:
    
            covloss = tf.reduce_sum(tf.minimum(a, coverage), [1])
            covlosses.append(covloss)
    
            coverage += a
        coverage_loss = _mask_and_avg(covlosses, padding_mask)
        return coverage_loss
    
    
    def _mask_and_avg(values, padding_mask):
        padding_mask = tf.cast(padding_mask, dtype=values[0].dtype)
        dec_lens = tf.reduce_sum(padding_mask, axis=1)
        values_per_step = [v * padding_mask[:, dec_step] for dec_step, v in enumerate(values)]
        values_per_ex = sum(values_per_step) / dec_lens
    
        return tf.reduce_mean(values_per_ex)
    

    Complete Code

    Code
