  • 『Paper Notes』SuperGlue

    https://zhuanlan.zhihu.com/p/342105673

    The feature-processing part is easy to follow; for the per-keypoint self- and cross-attention it is worth reading the source code (MultiHeadedAttention), shown below together with a simplified sketch of how the self/cross layers are wired:

    import torch
    from torch import nn
    from copy import deepcopy


    def attention(query, key, value):
        """ Scaled dot-product attention over the keypoint dimension """
        # query/key/value: b x d x h x n (batch, per-head dim, heads, keypoints)
        dim = query.shape[1]
        scores = torch.einsum('bdhn,bdhm->bhnm', query, key) / dim**.5
        prob = torch.nn.functional.softmax(scores, dim=-1)
        return torch.einsum('bhnm,bdhm->bdhn', prob, value), prob


    class MultiHeadedAttention(nn.Module):
        """ Multi-head attention to increase model expressivity """
        def __init__(self, num_heads: int, d_model: int):
            super().__init__()
            assert d_model % num_heads == 0
            self.dim = d_model // num_heads
            self.num_heads = num_heads
            self.merge = nn.Conv1d(d_model, d_model, kernel_size=1)
            self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)])
            self.prob = []  # stores attention maps for later inspection

        def forward(self, query, key, value):
            batch_dim = query.size(0)
            # project q/k/v with 1x1 convolutions and split into heads
            query, key, value = [l(x).view(batch_dim, self.dim, self.num_heads, -1)
                                 for l, x in zip(self.proj, (query, key, value))]
            x, prob = attention(query, key, value)
            self.prob.append(prob)
            # concatenate the heads and apply the output projection
            return self.merge(x.contiguous().view(batch_dim, self.dim*self.num_heads, -1))
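
    Self- and cross-attention differ only in which image supplies the key/value source. A simplified sketch of how the layers are wired (the MLP block here is abbreviated relative to the released code):

    class AttentionalPropagation(nn.Module):
        """ One message-passing step: x attends to `source` and is updated with the message """
        def __init__(self, feature_dim: int, num_heads: int):
            super().__init__()
            self.attn = MultiHeadedAttention(num_heads, feature_dim)
            # merge the node feature with the aggregated message (simplified MLP)
            self.mlp = nn.Sequential(
                nn.Conv1d(feature_dim * 2, feature_dim * 2, kernel_size=1),
                nn.ReLU(),
                nn.Conv1d(feature_dim * 2, feature_dim, kernel_size=1))

        def forward(self, x, source):
            message = self.attn(x, source, source)  # query = x, key/value = source
            return self.mlp(torch.cat([x, message], dim=1))


    class AttentionalGNN(nn.Module):
        def __init__(self, feature_dim: int, layer_names: list):
            super().__init__()
            self.layers = nn.ModuleList([
                AttentionalPropagation(feature_dim, 4) for _ in layer_names])
            self.names = layer_names  # e.g. ['self', 'cross'] repeated

        def forward(self, desc0, desc1):
            for layer, name in zip(self.layers, self.names):
                if name == 'cross':   # cross-attention: attend to the other image
                    src0, src1 = desc1, desc0
                else:                 # self-attention: attend within the same image
                    src0, src1 = desc0, desc1
                delta0, delta1 = layer(desc0, src0), layer(desc1, src1)
                desc0, desc1 = desc0 + delta0, desc1 + delta1
            return desc0, desc1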

    Now jump straight to the final matching logic. The paper is fairly terse here, and you need to read the source to see what it is actually doing (perhaps the experts among you won't need to).

    Look here: at inference time the matches are extracted without considering the last row and last column (the dustbin); instead a threshold is applied to filter out unqualified matches:

            # Get the matches with score above "match_threshold".
            # Drop the dustbin row/column, then take the row-wise and column-wise argmax.
            max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1)
            indices0, indices1 = max0.indices, max1.indices
            # Keep a pair only if it is a mutual nearest neighbour in both directions
            # (arange_like(x, dim) is a repo helper returning [0, 1, ..., x.shape[dim]-1]).
            mutual0 = arange_like(indices0, 1)[None] == indices1.gather(1, indices0)
            mutual1 = arange_like(indices1, 1)[None] == indices0.gather(1, indices1)
            zero = scores.new_tensor(0)
            mscores0 = torch.where(mutual0, max0.values.exp(), zero)
            mscores1 = torch.where(mutual1, mscores0.gather(1, indices1), zero)
            # Threshold on the matching score; unmatched keypoints are marked with -1.
            valid0 = mutual0 & (mscores0 > self.config['match_threshold'])
            valid1 = mutual1 & valid0.gather(1, indices1)
            indices0 = torch.where(valid0, indices0, indices0.new_tensor(-1))
            indices1 = torch.where(valid1, indices1, indices1.new_tensor(-1))
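
    To see what the mutual-nearest-neighbour check does in isolation, here is a tiny self-contained example (a hypothetical 1x3x4 score matrix and a threshold of 0.2; the repo's arange_like helper is replaced by torch.arange):

        import torch

        # hypothetical log-probabilities for 3 keypoints in A vs 4 in B (dustbin already removed)
        scores = torch.log(torch.tensor([[[0.70, 0.10, 0.10, 0.10],
                                          [0.20, 0.60, 0.10, 0.10],
                                          [0.40, 0.30, 0.15, 0.15]]]))
        max0, max1 = scores.max(2), scores.max(1)  # best match of each A-keypoint / B-keypoint
        indices0, indices1 = max0.indices, max1.indices
        mutual0 = torch.arange(indices0.shape[1])[None] == indices1.gather(1, indices0)
        mscores0 = torch.where(mutual0, max0.values.exp(), scores.new_tensor(0))
        valid0 = mutual0 & (mscores0 > 0.2)        # 0.2 stands in for config['match_threshold']
        print(torch.where(valid0, indices0, indices0.new_tensor(-1)))  # tensor([[ 0,  1, -1]])

    Keypoint 2 of image A prefers column 0, but column 0's best row is keypoint 0, so the pair is not mutual and keypoint 2 is marked unmatched (-1).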

    The inference-time code is as follows; note that the match results of image A against image B and vice versa (taking the argmax index along the rows/columns of scores) do not have to agree exactly:

                    kpts0, kpts1 = pred['keypoints0'].cpu().numpy()[0], pred['keypoints1'].cpu().numpy()[0]
                    matches, conf = pred['matches0'].cpu().detach().numpy(), pred['matching_scores0'].cpu().detach().numpy()
                    image0 = read_image_modified(image0, opt.resize, opt.resize_float)
                    image1 = read_image_modified(image1, opt.resize, opt.resize_float)
                    valid = matches > -1            # -1 marks keypoints of image A without a match
                    mkpts0 = kpts0[valid]           # matched keypoints in image A
                    mkpts1 = kpts1[matches[valid]]  # their counterparts in image B
                    mconf = conf[valid]             # corresponding matching scores

    Next, look at how the assignment matrix is solved. couplings is the similarity score matrix augmented with one extra row and one extra column. Note that these (m+n+1) extra entries are all views of one and the same scalar tensor in memory; it is initialized to 1 and is learnable. Under the constraints given in the paper, the Sinkhorn algorithm (which I still need to study properly) is used to solve for the assignment matrix Z:

    # Z: b x (m+1) x (n+1), log_mu: b x (m+1), log_nu: b x (n+1)
    def log_sinkhorn_iterations(Z, log_mu, log_nu, iters: int):
        """ Perform Sinkhorn Normalization in Log-space for stability"""
        u, v = torch.zeros_like(log_mu), torch.zeros_like(log_nu)
        for _ in range(iters):
            # alternately rescale rows and columns so the marginals match log_mu / log_nu
            u = log_mu - torch.logsumexp(Z + v.unsqueeze(1), dim=2)  # b x (m+1)
            v = log_nu - torch.logsumexp(Z + u.unsqueeze(2), dim=1)  # b x (n+1)
        return Z + u.unsqueeze(2) + v.unsqueeze(1)
    
    
    def log_optimal_transport(scores, alpha, iters: int):
        """ Perform Differentiable Optimal Transport in Log-space for stability"""
        b, m, n = scores.shape
        one = scores.new_tensor(1)
        ms, ns = (m*one).to(scores), (n*one).to(scores)
    
        bins0 = alpha.expand(b, m, 1)  # only a new view
        bins1 = alpha.expand(b, 1, n)
        alpha = alpha.expand(b, 1, 1)
    
        # couplings: b x (m+1) x (n+1); the extra row/column is filled with the shared dustbin score
        couplings = torch.cat([torch.cat([scores, bins0], -1),  # bmn, bm1 -> bm(n+1)
                               torch.cat([bins1, alpha], -1)], 1)  # b1n, b11 -> b1(n+1)

        norm = - (ms + ns).log()
        log_mu = torch.cat([norm.expand(m), ns.log()[None] + norm])  # m+1: [-log(m+n), ..., log(n)-log(m+n)]
        log_nu = torch.cat([norm.expand(n), ms.log()[None] + norm])  # n+1: [-log(m+n), ..., log(m)-log(m+n)]
        log_mu, log_nu = log_mu[None].expand(b, -1), log_nu[None].expand(b, -1)  # b x (m+1), b x (n+1)
    
        Z = log_sinkhorn_iterations(couplings, log_mu, log_nu, iters)
        Z = Z - norm  # multiply probabilities by M+N
        return Z
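
    As a quick sanity check on the shapes and the marginals, a minimal usage sketch (assuming the two functions above are in scope):

        import torch

        b, m, n = 1, 5, 7
        scores = torch.randn(b, m, n)                     # pairwise similarity scores
        bin_score = torch.nn.Parameter(torch.tensor(1.))  # the learnable dustbin score
        Z = log_optimal_transport(scores, bin_score, iters=100)
        print(Z.shape)                    # torch.Size([1, 6, 8]): one extra dustbin row and column
        # after exponentiation, every non-dustbin row (and column) carries a total mass of ~1,
        # while the dustbin row/column absorbs whatever is left over
        print(Z.exp()[0, :-1].sum(-1))    # ~ tensor([1., 1., 1., 1., 1.])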

    The loss simply maximizes (the log of) this assignment matrix Z, i.e. the scores matrix below, at the ground-truth matches. The match pairs themselves never contain dustbin points, so the handling of the dustbin is folded into the Sinkhorn step. Note the argument self.bin_score in the call below: it is a learnable parameter of the SuperGlue network:

            bin_score = torch.nn.Parameter(torch.tensor(1.))
            self.register_parameter('bin_score', bin_score)
    Looking back at the log_optimal_transport code above: this is exactly the value written into the extra row and column of couplings on every call.
            all_matches = data['all_matches'].permute(1,2,0) # shape=torch.Size([1, 87, 2])
    
            ……
    
            # Run the optimal transport.
            scores = log_optimal_transport(
                scores, self.bin_score,
                iters=self.config['sinkhorn_iterations'])
    
            ……
    
            # check if indexed correctly
            loss = []
            for i in range(len(all_matches[0])):
                x = all_matches[0][i][0]
                y = all_matches[0][i][1]
                loss.append(-torch.log( scores[0][x][y].exp() )) # check batch size == 1 ?
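
    Since log(exp(z)) == z, the exp/log pair inside the loop is redundant, and the loop itself can be vectorized; a minimal sketch, assuming batch size 1 and integer index tensors in all_matches:

            # vectorized equivalent of the loop above: per-match negative log-likelihoods
            xs, ys = all_matches[0, :, 0], all_matches[0, :, 1]
            loss = -scores[0, xs, ys]   # scores is already in log-space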

    The loss itself is easy to understand. Judging from the formula, the match entries in all_matches above should also include unmatched points (I have not verified this), e.g. keypoint i matched to dustbin index N+1; otherwise the last two terms of the loss would never show up:
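
    $$\mathrm{Loss} = -\sum_{(i,j)\in\mathcal{M}} \log \bar{P}_{i,j} \;-\; \sum_{i\in\mathcal{I}} \log \bar{P}_{i,N+1} \;-\; \sum_{j\in\mathcal{J}} \log \bar{P}_{M+1,j}$$

    (reproduced from the paper; $\mathcal{M}$ is the set of ground-truth matches, $\mathcal{I}$ and $\mathcal{J}$ are the keypoints of A and B left without a match)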

    The constraints the original paper places on the augmented assignment matrix are as follows:
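
    $$\bar{P}\,\mathbf{1}_{N+1} = \mathbf{a} \quad\text{and}\quad \bar{P}^{\top}\,\mathbf{1}_{M+1} = \mathbf{b}, \qquad \mathbf{a} = [\mathbf{1}_M^{\top}\ \; N]^{\top},\quad \mathbf{b} = [\mathbf{1}_N^{\top}\ \; M]^{\top}$$

    (as written in the paper; M and N are the numbers of keypoints detected in images A and B)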

    My own reading is that the N and M attached to a and b here were probably written the wrong way round; they stand for the unmatched points in images A and B. For the Sinkhorn algorithm, all that is needed is a mass-conservation condition, so the authors pad the assignment matrix P̄ with an extra row and column, and the score matrix S then has to be padded accordingly. To pad S, they use the rather odd trick mentioned in the code walkthrough above, where a single scalar tensor is shared in memory across an entire row and column. The paper's explanation here is very brief; my impression is that they simply tried padding the score matrix this way, found it worked well, and there is no deeper justification.

    As an aside, this S really is not easy to pad in a principled way: the keypoints in each image pair are arranged completely at random, and giving every slot of the extra row/column its own variable would not make much sense either. Sharing a single variable actually gives it something of a "threshold" feel, even though at inference time the extra row and column of the computed assignment matrix P̄ are simply discarded.

    Correspondingly, the constraints on P are easy to understand:
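
    $$P \in [0,1]^{M\times N}, \qquad P\,\mathbf{1}_N \le \mathbf{1}_M \quad\text{and}\quad P^{\top}\,\mathbf{1}_M \le \mathbf{1}_N$$

    (as written in the paper: each keypoint can be matched at most once)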

  • Original post: https://www.cnblogs.com/hellcat/p/15260145.html