zoukankan      html  css  js  c++  java
  • seaborn画热力图注意的几点问题

    最近在使用注意力机制实现文本分类,我们需要观察每一个样本中,模型的重心放在哪里了,就是观察到权重最大的token。这时我们需要使用热力图进行可视化。

    我这里用到:seaborn

    seaborn.heatmap

    seaborn.heatmap(data, vmin=None, vmax=None, cmap=None, center=None, robust=False, annot=None, fmt='.2g', annotkws=None, linewidths=0, linecolor='white', cbar=True, cbarkws=None, cbar_ax=None, square=False, ax=None, xticklabels=True, yticklabels=True, mask=None, **kwargs)

    • data:矩阵数据集,可以使numpy的数组(array),如果是pandas的dataframe,则df的index/column信息会分别对应到heatmap的columns和rows
    • linewidths,热力图矩阵之间的间隔大小
    • vmax,vmin, 图例中最大值和最小值的显示值,没有该参数时默认不显示

    data就是我们注意力矩阵的数据。注意,由于注意力的整理数值都偏小,直接使用数据显示的效果难以区分,我们可以将其放大100倍后来获取更加的效果。 先上代码吧!

    fr = open('./pkl/attention_matrix.pkl', 'rb')
    tokens, attention = pickle.load(fr)
    plt.figure(figsize=(30,20))
    sns.heatmap(attention, vamx=100, vmin=0)
    plt.savefig('./log/attention_matrix.png')
    
    # 获取数据
    import heapq
    check_file = './log/check_attention_keywords.txt'
    clean(check_file)
    fw = open(check_file, 'a', encoding='utf8')
    for t, a in zip(tokens, attention):
        temp = []
        max_num_index_list = map(list(a).index, heapq.nlargest(5, list(a))
        for index in max_num_index_list:
            word = t[index]
            print(word)
            temp.append(word)
        fw.write(str(temp)+'
    ')

      我这里取出注意力值最大的前5个词拿出来看的

  • 相关阅读:
    关系型数据库范式 沧海
    面试注意事项 沧海
    怎样在面试后得到想要的职位 沧海
    应届大学毕业生面试应答 沧海
    二叉树的遍历及实现 沧海
    比较好的C++面试题 沧海
    多态 沧海
    应届大学毕业生面试应答 沧海
    SQL Server开发人员应聘常被问的问题 沧海
    面试成功的技巧与忠告 沧海
  • 原文地址:https://www.cnblogs.com/demo-deng/p/10375408.html
Copyright © 2011-2022 走看看