zoukankan      html  css  js  c++  java
  • 基于tensor2tensor的注意力可视化

    根据训练好的 Transformer 模型得到注意力矩阵，并对注意力进行可视化。

    首先安装：tensorflow 1.13.1 + tensor2tensor 1.13.1。

    可视化代码请在 Jupyter Notebook 中运行。该代码根据 tensor2tensor/tensor2tensor/visualization/visualization.py 修改得到。

    # coding=utf-8
    # Copyright 2020 The Tensor2Tensor Authors.
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    """Shared code for visualizing transformer attentions."""
    
    from __future__ import absolute_import
    from __future__ import division
    from __future__ import print_function
    import numpy as np
    
    # To register the hparams set
    from tensor2tensor import models  # pylint: disable=unused-import
    from tensor2tensor import problems
    from tensor2tensor.utils import registry
    from tensor2tensor.utils import trainer_lib
    
    import tensorflow.compat.v1 as tf
    from tensor2tensor.utils import usr_dir
    # Token id appended to every encoded input sentence as the end-of-sequence
    # marker (see `AttentionVisualizer2.encode`). Presumably mirrors
    # tensor2tensor's `text_encoder.EOS_ID`; confirm if a custom vocab is used.
    EOS_ID = 1
    
    class AttentionVisualizer2(object):
      """Helper object for creating attention visualizations.

      Builds the model graph via `build_model` and keeps handles to the input
      and target placeholders, the decoded-sample tensor and the attention
      tensors, together with the problem's feature encoders that map between
      text and token ids.
      """

      def __init__(
          self, hparams_set, hparams, t2t_usr_dir, model_name, data_dir,
          problem_name, beam_size=1):
        """Builds the graph and loads the problem's feature encoders.

        Args:
          hparams_set: Name of the registered hparams set.
          hparams: Comma-separated hparams override string, forwarded to
            `build_model`.
          t2t_usr_dir: Directory with user-defined t2t code, forwarded to
            `build_model`.
          model_name: Name of the registered model.
          data_dir: Data directory; also passed to `feature_encoders`.
          problem_name: Name of the registered problem.
          beam_size: Beam size used when decoding; 1 means greedy decoding.
        """
        inputs, targets, samples, att_mats = build_model(
            hparams_set, hparams, t2t_usr_dir, model_name, data_dir,
            problem_name, beam_size=beam_size)

        # Fetch the problem so its feature encoders can be used for text <-> ids.
        problem = problems.problem(problem_name)
        encoders = problem.feature_encoders(data_dir)

        self.inputs = inputs
        self.targets = targets
        self.att_mats = att_mats
        self.samples = samples
        self.encoders = encoders

      @staticmethod
      def _to_id_list(integers):
        """Flattens an id array of any shape (e.g. [1, n, 1, 1]) into a list.

        `np.squeeze` turns a single-token array into a 0-d array, which `list()`
        cannot iterate; `np.ravel` always yields a 1-D array, so one-token
        inputs/outputs (e.g. an empty sentence encoded as just EOS) work too.
        """
        return list(np.ravel(integers))

      def encode(self, input_str):
        """Encodes an input string into a batch of ids ready for inference.

        Returns:
          A numpy array of shape [1, num_ids, 1, 1]: the encoded ids followed by
          EOS_ID, in the layout of the `inputs` placeholder.
        """
        inputs = self.encoders["inputs"].encode(input_str) + [EOS_ID]
        # 4-D [batch=1, length, 1, 1], matching the `inputs` placeholder shape.
        return np.reshape(inputs, [1, -1, 1, 1])

      def decode(self, integers):
        """Decodes an array of target ids (any shape) into a string."""
        return self.encoders["targets"].decode(self._to_id_list(integers))

      def encode_list(self, integers):
        """Decodes an array of *input* ids (any shape) into a list of tokens."""
        return self.encoders["inputs"].decode_list(self._to_id_list(integers))

      def decode_list(self, integers):
        """Decodes an array of target ids (any shape) into a list of tokens."""
        return self.encoders["targets"].decode_list(self._to_id_list(integers))

      def get_vis_data_from_string(self, sess, input_string):
        """Constructs the data needed for visualizing attentions.

        The input is first translated with the inference graph; the resulting
        translation is then fed back as targets through the training graph to
        obtain the attention weights.

        Args:
          sess: A tf.Session object with the model weights restored.
          input_string: The input sentence to be translated and visualized.

        Returns:
          Tuple of (
              output_string: The translated sentence.
              input_list: Tokenized input sentence.
              output_list: Tokenized translation.
              att_mats: Tuple of attention matrices; (
                  enc_atts: Encoder self attention weights.
                    A list of `num_layers` numpy arrays of size
                    (batch_size, num_heads, inp_len, inp_len)
                  dec_atts: Decoder self attention weights.
                    A list of `num_layers` numpy arrays of size
                    (batch_size, num_heads, out_len, out_len)
                  encdec_atts: Encoder-Decoder attention weights.
                    A list of `num_layers` numpy arrays of size
                    (batch_size, num_heads, out_len, inp_len)
              )
          )
        """
        encoded_inputs = self.encode(input_string)

        # Run inference graph to get the translation.
        out = sess.run(self.samples, {
            self.inputs: encoded_inputs,
        })

        # Run the decoded translation through the training graph to get the
        # attention tensors.
        att_mats = sess.run(self.att_mats, {
            self.inputs: encoded_inputs,
            self.targets: np.reshape(out, [1, -1, 1, 1]),
        })

        output_string = self.decode(out)
        input_list = self.encode_list(encoded_inputs)
        output_list = self.decode_list(out)

        return output_string, input_list, output_list, att_mats
    
    
    def build_model(hparams_set, hparams, t2t_usr_dir, model_name, data_dir,
                    problem_name, beam_size=1):
      """Build the graph required to fetch the attention weights.

      Args:
        hparams_set: Name of the registered HParams set to build the model with.
        hparams: Comma-separated string of hparams overrides applied on top of
          `hparams_set` (e.g. "max_length=128,num_hidden_layers=6").
        t2t_usr_dir: Directory with user-defined t2t code (e.g. a custom model);
          it is imported before the model / hparams set are looked up.
        model_name: Name of the registered model.
        data_dir: Path to directory containing training data.
        problem_name: Name of problem.
        beam_size: (Optional) Number of beams to use when decoding a translation.
            If set to 1 (default) then greedy decoding is used.

      Returns:
        Tuple of (
            inputs: Input placeholder to feed in ids to be translated.
            targets: Targets placeholder to feed to translation when fetching
                attention weights.
            samples: Tensor representing the ids of the translation.
            att_mats: Tensors representing the attention weights.
        )
      """
      print(model_name)
      # Import the user directory first so anything it registers (custom model,
      # hparams set, problem) can be found by the registry lookups below.
      usr_dir.import_usr_dir(t2t_usr_dir)
      # Keep the override string (`hparams`) and the resulting HParams object
      # under different names; they are different types.
      model_hparams = trainer_lib.create_hparams(
          hparams_set, hparams, data_dir=data_dir, problem_name=problem_name)

      translate_model = registry.model(model_name)(
          model_hparams, tf.estimator.ModeKeys.EVAL)

      inputs = tf.placeholder(tf.int32, shape=(1, None, 1, 1), name="inputs")
      targets = tf.placeholder(tf.int32, shape=(1, None, 1, 1), name="targets")
      translate_model({
          "inputs": inputs,
          "targets": targets,
      })

      # Must be called after building the training graph, so that the dict will
      # have been filled with the attention tensors. BUT before creating the
      # inference graph otherwise the dict will be filled with tensors from
      # inside a tf.while_loop from decoding and are marked unfetchable.
      atts = get_att_mats(translate_model, model_name)

      # Reuse the training-graph variables for the inference graph.
      with tf.variable_scope(tf.get_variable_scope(), reuse=True):
        samples = translate_model.infer({
            "inputs": inputs,
        }, beam_size=beam_size)["outputs"]

      return inputs, targets, samples, atts
    
    
    def get_att_mats(translate_model, model_name):
      """Gets the tensors representing the attentions from a built model.

      The attentions are stored in a dict on the model object
      (`translate_model.attention_weights`) while building the graph, keyed by
      a scope-like path such as
      "<model_name>/body/encoder/layer_0/self_attention/...".

      Args:
        translate_model: Model object to fetch the attention weights from.
        model_name: Name of the model; used as the top-level component of the
          attention keys ("<model_name>/body/...").

      Returns:
        Tuple of attention tensors; (
            enc_atts: Encoder self attention weights, one per encoder layer.
            dec_atts: Decoder self attention weights, one per decoder layer.
            encdec_atts: Encoder-Decoder attention weights, one per decoder
              layer.
        )
        After `sess.run` each becomes a list of numpy arrays of size
        (batch_size, num_heads, query_len, key_len), i.e. (.., inp_len, inp_len)
        for encoder self attention, (.., out_len, out_len) for decoder self
        attention and (.., out_len, inp_len) for encoder-decoder attention.

      Raises:
        KeyError: If an expected attention key is missing from
          `translate_model.attention_weights`.
      """
      hparams = translate_model.hparams
      weights = translate_model.attention_weights

      prefix = "%s/body/" % model_name
      postfix_self_attention = "/multihead_attention/dot_product_attention"
      if hparams.self_attention_type == "dot_product_relative":
        postfix_self_attention = ("/multihead_attention/"
                                  "dot_product_attention_relative")
      postfix_encdec = "/multihead_attention/dot_product_attention"

      # Encoder and decoder may have their own depth; a missing or 0 value falls
      # back to num_hidden_layers (presumably matching how tensor2tensor's
      # transformer picks its depth -- confirm for custom models).
      num_encoder_layers = (getattr(hparams, "num_encoder_layers", 0)
                            or hparams.num_hidden_layers)
      num_decoder_layers = (getattr(hparams, "num_decoder_layers", 0)
                            or hparams.num_hidden_layers)

      enc_atts = [
          weights["%sencoder/layer_%i/self_attention%s"
                  % (prefix, i, postfix_self_attention)]
          for i in range(num_encoder_layers)]
      dec_atts = [
          weights["%sdecoder/layer_%i/self_attention%s"
                  % (prefix, i, postfix_self_attention)]
          for i in range(num_decoder_layers)]
      encdec_atts = [
          weights["%sdecoder/layer_%i/encdec_attention%s"
                  % (prefix, i, postfix_encdec)]
          for i in range(num_decoder_layers)]

      return enc_atts, dec_atts, encdec_atts
    
    
    
    from IPython.display import display
    def call_html():
      """Injects require.js and its d3 / jquery path config into the notebook.

      Renders an HTML snippet in the current Jupyter output cell. In this
      example it is called right before `attention.show(...)`.
      """
      import IPython
      html_cls = IPython.core.display.HTML
      snippet = '''
            <script src="/static/components/requirejs/require.js"></script>
            <script>
              requirejs.config({
                paths: {
                  base: '/static/base',
                  "d3": "https://cdnjs.cloudflare.com/ajax/libs/d3/3.5.8/d3.min",
                  jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',
                },
              });
            </script>
            '''
      display(html_cls(snippet))
    
    
    import os
    from tensor2tensor import problems
    from tensor2tensor.bin import t2t_decoder  # To register the hparams set
    # from tensor2tensor.utils import registry
    from tensor2tensor.utils import trainer_lib
    from tensor2tensor.visualization import attention
    # from src.visualization import visualization
    # Expose only GPUs 0 and 1 to TensorFlow.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

    # ---- Experiment configuration (edit to match your own setup) ----
    problem_name = 'translate_ende_wmt32k'  # registered problem (dataset)
    data_dir = os.path.expanduser(
        '/home/usrname/collaboration/t2t_data/{}'.format(problem_name))  # data path
    model_name = "collaboration"  # registered model name
    hparams_set = "collaboration_base"  # registered hparams set (model config)
    # Comma-separated hparams overrides; adjust to your own needs.
    hparams = 'max_length=128,num_hidden_layers=6,usedegray=1.0,reuse_n=0'
    t2t_usr_dir = './src/'  # directory containing the user-defined model code

    visualizer = AttentionVisualizer2(
        hparams_set, hparams, t2t_usr_dir, model_name, data_dir, problem_name,
        beam_size=1)

    # Adds a `global_step` variable to the graph; presumably so the graph
    # matches the checkpoint restored by the Saver in the next cell -- TODO
    # confirm it is still required for your model.
    tf.Variable(0, dtype=tf.int64, trainable=False, name='global_step')
    

    接着继续运行:

    saver = tf.train.Saver()
    with tf.Session() as sess:
      ckpt = 'averaged.ckpt-0'  # checkpoint path
      print(ckpt)
      saver.restore(sess, ckpt)

      # Sentence to visualize. Alternative example:
      # input_sentence = "It is in this spirit that a majority of American governments have passed new laws since 2009 making the registration or voting process more difficult."
      input_sentence = "The Law will never be perfect, but its application should be just - this is what we are missing, in my opinion."
      # Everything that uses `sess` must stay inside this `with` block: the
      # session is closed as soon as the block exits.
      output_string, inp_text, out_text, att_mats = (
          visualizer.get_vis_data_from_string(sess, input_sentence))
      print(output_string)
      call_html()
      attention.show(inp_text, out_text, *att_mats)

    可视化结果:

      

  • 相关阅读:
    Spring MVC3 + Ehcache 缓存实现
    DB2导入导出数据库数据
    JS、ActiveXObject、Scripting.FileSystemObject
    new ActiveXObject("Scripting.FileSystemObject") 时抛出异常 .
    各种浏览器的内核是什么
    Content-Type: application/vnd.ms-excel">
    常用jar包用途
    nutz的json视图
    Nutz中那些好用的工具类
    The JSP specification requires that an attribute name is preceded by whitespace
  • 原文地址:https://www.cnblogs.com/huadongw/p/14195355.html
Copyright © 2011-2022 走看看