  • NMI Calculation

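    NMI (Normalized Mutual Information) is commonly used in clustering to measure how similar two clustering results are. It is an important metric in community detection, where it gives a fairly objective measure of how accurately a community partition matches the ground-truth partition. NMI ranges from 0 to 1; the higher the value, the closer the partition is to the ground truth.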
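For reference, the code in this post implements the following definition, which normalizes the mutual information by the arithmetic mean of the two entropies. Other normalizations (geometric mean, min, max) are also in use, so the formula below should be read as the variant used here rather than the only one. The symbols X and Y stand for the two label assignments and are notation introduced for this note.

```latex
I(X;Y) = \sum_{x}\sum_{y} p(x,y)\,\log_2\frac{p(x,y)}{p(x)\,p(y)},
\qquad
H(X) = -\sum_{x} p(x)\,\log_2 p(x)

\mathrm{NMI}(X,Y) = \frac{2\,I(X;Y)}{H(X) + H(Y)}
```

Here p(x) is the fraction of samples carrying label x in the first partition, p(y) the same for the second partition, and p(x,y) the fraction of samples carrying both labels.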

    # -*- coding: utf-8 -*-
    '''
    Created on 2017-10-28

    @summary: Compute NMI in Python

    @author: dreamhome
    '''
    import math
    import numpy as np
    from sklearn import metrics

    def NMI(A, B):
        # Number of samples; len(A) must equal len(B)
        total = len(A)
        A_ids = set(A)
        B_ids = set(B)
        # Mutual information
        MI = 0
        eps = 1.4e-45  # keeps math.log defined when pxy == 0
        for idA in A_ids:
            for idB in B_ids:
                idAOccur = np.where(A == idA)
                idBOccur = np.where(B == idB)
                idABOccur = np.intersect1d(idAOccur, idBOccur)
                px = 1.0 * len(idAOccur[0]) / total
                py = 1.0 * len(idBOccur[0]) / total
                pxy = 1.0 * len(idABOccur) / total
                MI = MI + pxy * math.log(pxy / (px * py) + eps, 2)
        # Normalized mutual information
        Hx = 0
        for idA in A_ids:
            idAOccurCount = 1.0 * len(np.where(A == idA)[0])
            Hx = Hx - (idAOccurCount / total) * math.log(idAOccurCount / total + eps, 2)
        Hy = 0
        for idB in B_ids:
            idBOccurCount = 1.0 * len(np.where(B == idB)[0])
            Hy = Hy - (idBOccurCount / total) * math.log(idBOccurCount / total + eps, 2)
        if Hx + Hy == 0:
            # Both labelings put every sample in a single cluster;
            # treated here as perfect agreement to avoid dividing by zero
            return 1.0
        MIhat = 2.0 * MI / (Hx + Hy)
        return MIhat

    if __name__ == '__main__':
        A = np.array([1,1,1,1,1,1,2,2,2,2,2,2,3,3,3,3,3])
        B = np.array([1,2,1,1,1,1,1,2,2,2,2,3,1,1,3,3,3])
        print(NMI(A, B))
        # Cross-check against scikit-learn
        print(metrics.normalized_mutual_info_score(A, B))
    
    
    Source: https://blog.csdn.net/DreamHome_S/article/details/78379635
    Code found online

    Result: 0.36456
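The loop-based function above rescans both arrays for every pair of labels. As a rough alternative, the same quantity can be computed from a contingency table of label co-occurrences. This is a minimal sketch and not the original author's code: the name `nmi_contingency` is made up for illustration, it assumes two non-empty 1-D label arrays of equal length, and it returns 1.0 when both labelings consist of a single cluster (an assumption about how to handle the zero-entropy case).

```python
import numpy as np


def nmi_contingency(a, b):
    """NMI with arithmetic-mean normalization: 2*I(A;B) / (H(A) + H(B)).

    Hypothetical helper for illustration; expects non-empty 1-D label arrays.
    """
    a = np.asarray(a).ravel()
    b = np.asarray(b).ravel()
    if a.shape != b.shape:
        raise ValueError("label arrays must have the same length")

    # Map arbitrary labels to consecutive indices 0..k-1.
    _, a_idx = np.unique(a, return_inverse=True)
    _, b_idx = np.unique(b, return_inverse=True)

    # counts[i, j] = number of samples with A-label i and B-label j.
    counts = np.zeros((a_idx.max() + 1, b_idx.max() + 1))
    np.add.at(counts, (a_idx, b_idx), 1)

    pxy = counts / counts.sum()   # joint distribution
    px = pxy.sum(axis=1)          # marginal of A
    py = pxy.sum(axis=0)          # marginal of B

    # Marginals are strictly positive because every label occurs at least once.
    hx = -np.sum(px * np.log2(px))
    hy = -np.sum(py * np.log2(py))
    if hx + hy == 0:
        # Assumption: two single-cluster labelings count as perfect agreement.
        return 1.0

    # Sum only over non-empty cells, so no epsilon is needed inside the log.
    nz = pxy > 0
    mi = np.sum(pxy[nz] * np.log2(pxy[nz] / np.outer(px, py)[nz]))
    return float(2.0 * mi / (hx + hy))


if __name__ == "__main__":
    A = np.array([1,1,1,1,1,1,2,2,2,2,2,2,3,3,3,3,3])
    B = np.array([1,2,1,1,1,1,1,2,2,2,2,3,1,1,3,3,3])
    print(round(nmi_contingency(A, B), 5))  # 0.36456, same as the result above
```

Skipping empty cells replaces the `eps` trick in the original function, and building the table once avoids the repeated `np.where` scans.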

    This blog post explains it well.

    I wrote my own version, which also sorts the results.

    # coding=utf-8
    import numpy as np
    import math
    import operator


    def NMI(A, B):
        # len(A) should be equal to len(B)
        total = len(A)
        A_ids = set(A)
        B_ids = set(B)
        # Mutual information
        MI = 0
        eps = 1.4e-45  # keeps math.log defined when pxy == 0
        for idA in A_ids:
            for idB in B_ids:
                idAOccur = np.where(A == idA)
                idBOccur = np.where(B == idB)
                idABOccur = np.intersect1d(idAOccur, idBOccur)
                px = 1.0 * len(idAOccur[0]) / total
                py = 1.0 * len(idBOccur[0]) / total
                pxy = 1.0 * len(idABOccur) / total
                MI = MI + pxy * math.log(pxy / (px * py) + eps, 2)
        # Normalized mutual information
        Hx = 0
        for idA in A_ids:
            idAOccurCount = 1.0 * len(np.where(A == idA)[0])
            Hx = Hx - (idAOccurCount / total) * math.log(idAOccurCount / total + eps, 2)
        Hy = 0
        for idB in B_ids:
            idBOccurCount = 1.0 * len(np.where(B == idB)[0])
            Hy = Hy - (idBOccurCount / total) * math.log(idBOccurCount / total + eps, 2)
        MIhat = 2.0 * MI / (Hx + Hy)
        return MIhat


    if __name__ == '__main__':
        A = np.array([1,1,1])
        B = np.array([2,3,4])
        C = np.array([1,1,6])
        print(NMI(A, B))
        # NMI values in pair order: (A,B), (B,C), (A,C)
        m = [NMI(A, B), NMI(B, C), NMI(A, C)]
        # Key each value by its 1-based position; keep the values as floats
        # so that the sort below is numeric rather than lexicographic
        dic = {}
        for q, value in enumerate(m, start=1):
            dic['NMI #{}'.format(q)] = value
        print(dic)
        # Rank from highest to lowest NMI
        rankdata = sorted(dic.items(), key=operator.itemgetter(1), reverse=True)
        print(rankdata)

    The experimental results are shown in the figure.
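The three values in the script can also be checked by hand from the definitions the code uses. The following is a worked calculation under those definitions, not a copy of the recorded output: A = [1,1,1] is constant and so carries no information, while every element of B = [2,3,4] is distinct, so B fully determines C = [1,1,6] and I(B;C) equals H(C).

```latex
\begin{aligned}
H(A) &= 0, \qquad H(B) = \log_2 3 \approx 1.585, \qquad
H(C) = -\tfrac{2}{3}\log_2\tfrac{2}{3} - \tfrac{1}{3}\log_2\tfrac{1}{3} \approx 0.918 \\
I(A;B) &= I(A;C) = 0
  \;\Rightarrow\; \mathrm{NMI}(A,B) = \mathrm{NMI}(A,C) = 0 \\
I(B;C) &= H(C) \approx 0.918
  \;\Rightarrow\; \mathrm{NMI}(B,C) = \frac{2 \times 0.918}{1.585 + 0.918} \approx 0.734
\end{aligned}
```

So the pair (B, C) should rank first, with the two pairs involving A tied at zero.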

  • Original article: https://www.cnblogs.com/xingnie/p/10334897.html