zoukankan      html  css  js  c++  java
  • 图像匹配 | NCC 归一化互相关损失 | 代码 + 讲解

    • 文章转载自:微信公众号「机器学习炼丹术」
    • 作者:炼丹兄(已授权)
    • 作者联系方式:微信cyx645016617(欢迎交流共同进步)

    本次的内容主要讲解NCCNormalized cross-correlation 归一化互相关。

    两张图片是否是同一个内容,现在深度学习的方案自然是用神经网络,比方说:孪生网络的架构做人面识别等等;

    在传统的非参数方法中,常见的也有相关系数等。我在上一片文章voxelmorph的模型的学习中发现,在医学图像配准任务(不限于医学),衡量两个图片相似的度量有一种叫做NCC的

    而这个NCC就是Normalized Cross-Correlation归一化互相关系数。

    1 互相关系数

    如果你知道互相关系数,那么你就能很好的理解归一化互相关系数。

    相关系数的计算公式如下:

    [r(X,Y) = frac{Cov(X,Y)}{sqrt{Var(X)Var(Y)}} ]

    公式中的X,Y分别表示两个图片,(Cov(X,Y))表示两个图片的协方差,(Var(X))表示X自身的方差;

    2 归一化互相关NCC

    如果把一张图片,按照一定的像素,比方说9x9的一个框滑动,那么就可以把图片分成很多的9x9的小图片,那么NCC就是X,Y两张大图片中的对应的小图片的互相关系数的平均值。

    这里看一下协方差的计算方式:
    (Cov(X,Y) = E[(X-E(X))(Y-E(Y))])

    方差的计算为:
    (Var(X) = E[(X-E(X))^2])

    其实NCC不难理解,但是如何用代码计算呢?当然我们可以一行一行遍历求解,但是这样时间复杂度过高,所以我们做好还是选择矩阵运算。

    3 NCC损失函数的代码

    class NCC:
        """
        Local (over window) normalized cross correlation loss.
        """
    
        def __init__(self, win=None):
            self.win = win
    
        def loss(self, y_true, y_pred):
    
            I = y_true
            J = y_pred
    
            # get dimension of volume
            # assumes I, J are sized [batch_size, *vol_shape, nb_feats]
            ndims = len(list(I.size())) - 2
            assert ndims in [1, 2, 3], "volumes should be 1 to 3 dimensions. found: %d" % ndims
    
            # set window size
            win = [9] * ndims if self.win is None else self.win
    
            # compute filters
            sum_filt = torch.ones([1, 1, *win]).to("cuda")
    
            pad_no = math.floor(win[0]/2)
    
            if ndims == 1:
                stride = (1)
                padding = (pad_no)
            elif ndims == 2:
                stride = (1,1)
                padding = (pad_no, pad_no)
            else:
                stride = (1,1,1)
                padding = (pad_no, pad_no, pad_no)
    
            # get convolution function
            conv_fn = getattr(F, 'conv%dd' % ndims)
    
            # compute CC squares
            I2 = I * I
            J2 = J * J
            IJ = I * J
    
            I_sum = conv_fn(I, sum_filt, stride=stride, padding=padding)
            J_sum = conv_fn(J, sum_filt, stride=stride, padding=padding)
            I2_sum = conv_fn(I2, sum_filt, stride=stride, padding=padding)
            J2_sum = conv_fn(J2, sum_filt, stride=stride, padding=padding)
            IJ_sum = conv_fn(IJ, sum_filt, stride=stride, padding=padding)
    
            win_size = np.prod(win)
            u_I = I_sum / win_size
            u_J = J_sum / win_size
    
            cross = IJ_sum - u_J * I_sum - u_I * J_sum + u_I * u_J * win_size
            I_var = I2_sum - 2 * u_I * I_sum + u_I * u_I * win_size
            J_var = J2_sum - 2 * u_J * J_sum + u_J * u_J * win_size
    
            cc = cross * cross / (I_var * J_var + 1e-5)
    
            return -torch.mean(cc)
    

    这段代码其实不是很好看懂,我思考了很久才明白。其中的关键就在于如何理解:

    # compute CC squares
            I2 = I * I
            J2 = J * J
            IJ = I * J
    
            I_sum = conv_fn(I, sum_filt, stride=stride, padding=padding)
            J_sum = conv_fn(J, sum_filt, stride=stride, padding=padding)
            I2_sum = conv_fn(I2, sum_filt, stride=stride, padding=padding)
            J2_sum = conv_fn(J2, sum_filt, stride=stride, padding=padding)
            IJ_sum = conv_fn(IJ, sum_filt, stride=stride, padding=padding)
    
            win_size = np.prod(win)
            u_I = I_sum / win_size
            u_J = J_sum / win_size
    
            cross = IJ_sum - u_J * I_sum - u_I * J_sum + u_I * u_J * win_size
            I_var = I2_sum - 2 * u_I * I_sum + u_I * u_I * win_size
            J_var = J2_sum - 2 * u_J * J_sum + u_J * u_J * win_size
    

    我们可以才到,这个cross应该是协方差部分,I_var和J_var是方差部分。

    我们对协方差公式进行推导:(Cov(X,Y) = E[(X-E(X))(Y-E(Y))])
    (=E[XY-XE(Y)-YE(X)+E(X)E(Y)])

    这样刚好和cross对应上。

    • IJ_sum = E[XY]
    • u_J * I_sum = E[XE(Y)]
    • u_I * u_J * win_size = E[E(X)E(Y)]

    对方差公式进行推导:(Var(X) = E[(X-E(X))^2]=E[X^2-2XE(X)+E(X)^2])

    • J2_sum = E(X^2)
    • 2 * u_J * J_sum = E[2XE(X)]
    • u_J * u_J * win_size = E[E(X)^2]
    人不可傲慢。
  • 相关阅读:
    javascript 下拉列表 自动取值 无需value
    gopkg:一种方便的go pakcage管理方式
    编译GDAL使用最新的HDF库配置文件
    leetcode:程序员面试技巧
    【Unity Shader实战】卡通风格的Shader(一)
    GDAL1.11版本对SHP文件索引加速测试
    【Unity Shaders】Shader学习资源和Surface Shader概述
    关于rowid的函数
    java基本类型的大小
    【转载】oracle之rowid详解
  • 原文地址:https://www.cnblogs.com/PythonLearner/p/14541639.html
Copyright © 2011-2022 走看看