zoukankan      html  css  js  c++  java
  • 如何可视化深度学习网络中Attention层

    前言

    在训练深度学习模型时,常想一窥网络结构中的attention层权重分布,观察序列输入的哪些词或者词组合是网络比较care的。在小论文中主要研究了关于词性POS对输入序列的注意力机制。同时对比实验采取的是words的self-attention机制。

    基于POS-Attention的层次化模型

    效果

    下图主要包含两列:word_attention是self-attention机制的模型训练结果,POS_attention是词性模型的训练结果。
    可以看出,相对于word_attention,POS的注意力机制不仅能够捕捉到评价的aspect,也能根据aspect关联的词借助情感语义表达的词性分布,care到相关词性的情感词。

    Attention可视化对比结果

    核心代码

    可视化样例

    # coding: utf-8
    def highlight(word, attn):
        html_color = '#%02X%02X%02X' % (255, int(255*(1 - attn)), int(255*(1 - attn)))
        return '<span style="background-color: {}">{}</span>'.format(html_color, word)
    
    def mk_html(seq, attns):
        html = ""
        for ix, attn in zip(seq, attns):
            html += ' ' + highlight(
                ix,
                attn
            )
        return html + "<br>"
    
    from IPython.display import HTML, display
    batch_size = 1
    seqs = [["这", "是", "一个", "测试", "样例", "而已"]]
    attns = [[0.01, 0.19, 0.12, 0.7, 0.2, 0.1]]
    
    for i in range(batch_size):
        text = mk_html(seqs[i], attns[i])
        display(HTML(text))
    

    接入model

    需要在model的返回列表中,添加attention_weight的输出,理论上维度应该和输入序列的长度是一致的。

    # load model
    import torch
    # if you train on gpu, you need to move onto cpu
    model = torch.load("../docs/model_chk/2018-11-07-02:45:37", map_location=lambda storage, location: storage)
    
    from torch.autograd import Variable
    for batch_idx, samples in enumerate(test_loader, 0):
        v_word = Variable(samples['word_vec'])
        v_final_label = samples['top_label']
    
        model.eval()
        final_probs, att_weight = model(v_word, v_pos)
    
        batch_words = toWords(samples["word_vec"].numpy(), idx_word)  # id转化为word
        batch_att = getAtten(batch_words, att_weight.data.numpy())    # 去除padding词,根据words的长度截取attention
        labels = toLabel(samples['top_label'].numpy())  # 真实标签
        pre_labels = toLabel(final_probs.data.numpy() >= 0.5)   # 预测标签
    
        for i in range(len(batch_words)):
            text = mk_html(batch_words[i], batch_att[i])
            print(labels[i], pre_labels[i])
            display(HTML(text))
    

    总结

    • 建议把可视化独立出来,用jupyter-notebook编辑,方便分段调试和copy;同时因为是借助html渲染的,所以需要notebook
    • 项目代码我后期后同步到github上,欢迎一起交流
  • 相关阅读:
    为什么大多Virtual Globe程序纵向旋转效率比较低
    惠普卖印刷服务 GIS卖什么?
    OpenLayers的新功能:矢量支持
    Google部分开源GMap API
    为OpenLayers 2.3添加Overview窗口
    从Grid控件到GIS软件
    GIS(数据)浏览器的点点滴滴
    ArcGIS 9.3和ArcGIS 10,一点感想
    关注:Pitney Bowes以4.08亿美金收购Mapinfo
    ArcGIS Server安装的几个问题
  • 原文地址:https://www.cnblogs.com/CocoML/p/12726004.html
Copyright © 2011-2022 走看看