  • Disentangled Non-Local Neural Networks

    I. Paper Overview

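    This paper combines theory with practice. Some of the theory does not feel entirely convincing to me, and I can't always follow the authors' reasoning. Overall I think it is a very good paper, but it was hard to understand on the first read.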

    It addresses the problem of the limited (local) receptive field and extends the previous paper.

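    Main contributions (some of these may have been proposed by others before):

    1. Designs a block that overcomes the local receptive field limitation.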

    II. Module Details

    2.1 Overview of the Paper's Idea

    The method is built entirely as an improvement on the previous paper, which I refer to as paper A below:

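    Paper A essentially expresses the response at position $i$ as $y_i = \frac{1}{\mathcal{C}(x)}\sum_j f(x_i,x_j)\,g(x_j)$. This means the representation of the current pixel depends on the surrounding pixels. $f(x_i,x_j)$ gives the weight of each surrounding pixel, and $g(x_j)$ is the transformation applied to that surrounding pixel. You can also simplify by taking $g$ as the identity, which gives $\sum_j f(x_i,x_j)\,x_j$.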
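For reference, below is a minimal NumPy sketch of paper A's operation, using the embedded-Gaussian choice of $f$ (a softmax over dot products). It rests on a few assumptions. Features are flattened to a [positions, channels] matrix, and each 1x1 convolution is replaced by a plain matrix product. The names `non_local`, `w_theta`, `w_phi` and `w_g` are hypothetical.

```python
import numpy as np


def softmax(z, axis=-1):
    """Numerically stable softmax along `axis`."""
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def non_local(x, w_theta, w_phi, w_g=None):
    """Embedded-Gaussian non-local operation on flattened features.

    x:       [P, C] features, one row per pixel position (P = H*W)
    w_theta: [C, C'] query projection (stands in for a 1x1 conv)
    w_phi:   [C, C'] key projection
    w_g:     [C, C''] value projection; None means g is the identity
    Returns y with y_i = sum_j softmax_j(theta(x_i)^T phi(x_j)) * g(x_j).
    """
    theta = x @ w_theta                        # [P, C']
    phi = x @ w_phi                            # [P, C']
    weights = softmax(theta @ phi.T, axis=1)   # [P, P]: row i is f(x_i, .) / C(x)
    g = x if w_g is None else x @ w_g          # identity g when w_g is None
    return weights @ g, weights


rng = np.random.default_rng(0)
x = rng.standard_normal((6, 4))    # 6 positions, 4 channels
w_theta = rng.standard_normal((4, 2))
w_phi = rng.standard_normal((4, 2))
y, weights = non_local(x, w_theta, w_phi)
print(y.shape, weights.sum(axis=1))  # each row of weights sums to 1
```

With $g$ as the identity, every $y_i$ is a convex combination of the inputs. So if all positions carry the same feature, the output equals the input.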

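    The weakness of paper A is that when the surrounding pixels are fairly similar, $f(x_i,x_j)$ (the relation between the current pixel and a surrounding pixel) degrades into what is effectively a unary function. It then no longer serves its original purpose, which is to model the relationship between the current pixel and its neighbors.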

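    This paper finds that $f(x_i,x_j)$ cannot be regarded purely as a relation between the two pixels, because it also contains another component. In the paper's terms, this pairwise function consists of a unary term plus a (whitened) pairwise term, and the two should be expressed separately.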
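The degeneration can be reproduced with hand-picked numbers. In this NumPy sketch, the toy values are chosen by hand for illustration and are not taken from the paper. Two queries share a large common component, so their raw attention rows come out almost identical: the logits are dominated by $\mu_q^T k_j$, which is the same for every query. After the means are subtracted, the two queries attend to different positions.

```python
import numpy as np


def softmax(z, axis=-1):
    """Numerically stable softmax along `axis`."""
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


# Hand-picked toy features: 2 positions (rows), 2 channels.
# Both queries share the large common component (10, 0).
q = np.array([[10.0, 1.0], [10.0, -1.0]])
k = np.array([[1.0, 1.0], [0.0, -1.0]])

# Plain attention: softmax over j of q_i^T k_j
raw = softmax(q @ k.T, axis=1)

# Whitened attention: subtract the mean over positions before the dot product
qw = q - q.mean(axis=0)
kw = k - k.mean(axis=0)
white = softmax(qw @ kw.T, axis=1)

# Unary term mu_q^T k_j: identical for every query i
unary = q.mean(axis=0) @ k.T  # values 10 and 0 for j = 0, 1

print(np.round(raw, 4))    # both rows put almost all weight on position 0
print(np.round(white, 4))  # the two queries now attend to different positions
print(unary)
```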

    The figure below shows how the functions expressed by the different modules differ:


    2.2 Implementation Details

    2.2.1 Theory

    • Equation (3): how is it obtained? The paper (see the figure below) only mentions it in passing:

    Notes:

    Here, "key" corresponds to "unary" and "query" to "pairwise"; the two pairs of terms mean the same thing in this context.

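    The paper applies whitening (mean subtraction). The formula aims to maximize the separation between key and query, that is, to make the two as independent of each other as possible. That way, surrounding pixels that look similar do not distort the overall judgment.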
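Whitening here simply means subtracting the mean over positions. This minimal NumPy sketch assumes a [channels, positions] layout, mirroring `whiten_type='channel'` in the second implementation below with the batch dimension dropped. The helper names are hypothetical. The sketch shows the practical consequence: any component shared by all positions drops out of the whitened pairwise term.

```python
import numpy as np


def whiten(feat):
    """Subtract the per-channel mean over all positions; feat is [C, P]."""
    return feat - feat.mean(axis=1, keepdims=True)


def whitened_logits(query, key):
    """[P, P] matrix whose (i, j) entry is (q_i - mu_q)^T (k_j - mu_k)."""
    return whiten(query).T @ whiten(key)


rng = np.random.default_rng(1)
query = rng.standard_normal((4, 5))  # C = 4 channels, P = 5 positions
key = rng.standard_normal((4, 5))

# Add a component shared by every position (e.g. a uniform background).
shift_q = rng.standard_normal((4, 1))
shift_k = rng.standard_normal((4, 1))
before = whitened_logits(query, key)
after = whitened_logits(query + shift_q, key + shift_k)

print(np.allclose(before, after))  # True: the shared component has no effect
```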

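    Here $q_i, q_j$ denote the query features at the current position and at a surrounding position. $k_m, k_n$ denote the key features at the current position and at a surrounding position.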

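    The paper measures the similarity between the two with a dot product. Writing it out with the Gaussian function would be more complicated, so the simpler form is used (see paper A).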

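    With this in mind, the formula below is fairly clear. I simplify it to $q_i^Tk_m - q_i^Tk_n - k_m^Tq_j$. The first term is the similarity between query and key at the current position (the larger the better). The second and third terms are each one's similarity to the other's surrounding pixels (the smaller the better). Maximizing this function therefore maximizes the discrimination between the two. Equivalently, you can read the first term as the discriminative part and the second and third terms as the correlation part, which may be easier to follow.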

    In the formula below, the numerator is the discrimination term and the denominator is a normalizing sum.

    • Equation (4): the authors also only mention it in passing

    Notes:

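    Up to this point, the paper effectively takes $q_i^Tk_j$ to mean the whitened product $(q_i-\mu_q)^T(k_j-\mu_k)$. Why do three extra terms suddenly appear here?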

    Because the paper's recurring point is that $f(x_i,x_j)$ contains more than the (whitened) $q_i^Tk_j$: it also implicitly contains a unary term.

    So what exactly is the unary term?

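    Since it is not known in advance, list every term. Expanding the un-whitened product gives $q_i^Tk_j = (q_i-\mu_q)^T(k_j-\mu_k) + \mu_q^Tk_j + q_i^T\mu_k - \mu_q^T\mu_k$. The extra terms $\mu_q^Tk_j + q_i^T\mu_k - \mu_q^T\mu_k$ are simply all the remaining combinations from this expansion. Of these, $q_i^T\mu_k$ and $\mu_q^T\mu_k$ do not depend on $j$, so they cancel in the softmax over $j$, leaving $\mu_q^Tk_j$ as the unary term. The paper does not discuss further what role each individual term plays.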
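The expansion can be checked numerically. This NumPy sketch uses random toy data with one row per position. It also shows that the terms $q_i^T\mu_k$ and $\mu_q^T\mu_k$, which do not depend on $j$, cancel in the softmax over $j$, leaving $\mu_q^Tk_j$ as the unary term. It checks the algebra only and does not reproduce any experiment from the paper.

```python
import numpy as np


def softmax(z, axis=-1):
    """Numerically stable softmax along `axis`."""
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


rng = np.random.default_rng(2)
q = rng.standard_normal((5, 3))  # rows: q_i for 5 positions, 3 channels
k = rng.standard_normal((5, 3))  # rows: k_j
mu_q, mu_k = q.mean(axis=0), k.mean(axis=0)

raw = q @ k.T                                   # q_i^T k_j
pairwise = (q - mu_q) @ (k - mu_k).T            # whitened pairwise term
unary = (k @ mu_q)[None, :]                     # mu_q^T k_j, depends on j only
const_in_j = (q @ mu_k)[:, None] - mu_q @ mu_k  # q_i^T mu_k - mu_q^T mu_k

# 1) The expansion is exact.
print(np.allclose(raw, pairwise + unary + const_in_j))  # True
# 2) Terms that do not depend on j cancel in the softmax over j.
print(np.allclose(softmax(raw, axis=1), softmax(pairwise + unary, axis=1)))  # True
```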

    • Visual interpretation of the formula (Section 3.2 of the paper)

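    This part demonstrates the theory in practice. It visualizes and analyzes how the learned attention overlaps with the label regions and their boundaries.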

    • Benefits seen from backpropagation (Section 3.3 of the paper)

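    The paper analyzes the advantage through the backward pass. By the chain rule, combining the two terms by addition (add) keeps their gradients more decoupled than combining them by multiplication (multi).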
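Here is a simplified scalar illustration of this coupling. It assumes unnormalized exponentials stand in for the softmax, so it is not the paper's full derivation. Let $p$ be the pairwise logit and $u$ the unary logit. The original block effectively uses $e^{p+u}=e^{p}e^{u}$, while the disentangled block adds two separately normalized maps:

```latex
\frac{\partial}{\partial p}\, e^{p+u} = e^{p}\, e^{u}
\qquad\text{vs.}\qquad
\frac{\partial}{\partial p}\left(e^{p} + e^{u}\right) = e^{p}
```

In the multiplicative form, the gradient reaching the pairwise logit is scaled by the unary factor $e^{u}$. In the additive form, it does not depend on $u$ at all.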

    • Derivation (Appendix)

    There, the Hessian matrix is negative definite ("less than 0"), so the stationary point is a maximum.
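The second-derivative test behind this remark can be illustrated on a toy concave quadratic. The quadratic is chosen for illustration and is not the appendix's objective. A negative definite Hessian at a stationary point means that point is a maximum, which this NumPy sketch checks numerically.

```python
import numpy as np

# Toy concave objective J(w) = b^T w - w^T A w with A symmetric positive definite.
# Illustrative only; this is not the objective from the paper's appendix.
A = np.array([[2.0, 0.5], [0.5, 1.0]])
b = np.array([1.0, -1.0])


def J(w):
    return b @ w - w @ A @ w


hessian = -2.0 * A                     # second derivative of J (A is symmetric)
eigvals = np.linalg.eigvalsh(hessian)  # symmetric matrix -> real eigenvalues
w_star = np.linalg.solve(2.0 * A, b)   # stationary point: b - 2 A w = 0

print(eigvals)  # both negative: the Hessian is negative definite

# Any small step away from the stationary point lowers J, so w_star is a maximum.
rng = np.random.default_rng(3)
steps = rng.standard_normal((100, 2)) * 1e-2
print(all(J(w_star + s) < J(w_star) for s in steps))  # True
```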

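    2.2.2 Code Implementation

    The figure below is only an overall flowchart; you need to read it together with the formulas to get the concrete implementation.

    There are two main implementation versions, and both feel incomplete to me.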

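    g_k = conv(x), g_q = conv(x), g_v = conv(x)

    k_mean = mean(g_k), q_mean = mean(g_q)    # per-channel means over all positions

    g_k = g_k - k_mean, g_q = g_q - q_mean    # whitening

    g_pnl = softmax(g_q^T * g_k), g_m = softmax(q_mean^T * g_k)    # g_m is the unary term μ_q^T k_j from the formula; the softmax over j is the same whether or not g_k is whitened

    g_dnl = g_pnl + g_m

    g_dnl = g_v * g_dnl

    x = x + g_dnl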
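The pseudocode above can be turned into a runnable NumPy sketch for a single image, under several assumptions:

- Features are laid out as [channels, positions].
- Every 1x1 convolution is a plain matrix product without bias.
- The unary branch follows the `with_unary` branch of the second implementation below (the mean query dotted with the keys).
- The 1/sqrt(C') scaling, temperature, `gamma` and normalization layers used in the implementations are omitted.

`dnl_forward` and the weight names are hypothetical.

```python
import numpy as np


def softmax(z, axis=-1):
    """Numerically stable softmax along `axis`."""
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def dnl_forward(x, w_q, w_k, w_v, w_out):
    """Disentangled non-local block on one image (sketch).

    x: [C, P] features; w_q, w_k, w_v: [C', C]; w_out: [C, C'].
    Returns the block output plus the two attention maps.
    """
    q, k, v = w_q @ x, w_k @ x, w_v @ x        # [C', P] each
    mu_q = q.mean(axis=1, keepdims=True)       # [C', 1], mean over positions
    mu_k = k.mean(axis=1, keepdims=True)
    # Whitened pairwise term: softmax over key positions j for each query i
    pairwise = softmax((q - mu_q).T @ (k - mu_k), axis=1)  # [P, P]
    # Unary term mu_q^T k_j: one distribution over j shared by every query
    unary = softmax(mu_q.T @ k, axis=1)                    # [1, P]
    attn = pairwise + unary                                # broadcasts to [P, P]
    y = v @ attn.T            # y[:, i] = sum_j attn[i, j] * v[:, j]
    return x + w_out @ y, pairwise, unary


rng = np.random.default_rng(4)
C, Cp, P = 8, 4, 6
x = rng.standard_normal((C, P))
w_q, w_k, w_v = (rng.standard_normal((Cp, C)) * 0.1 for _ in range(3))
w_out = np.zeros((C, Cp))  # zero-initialized output projection
out, pairwise, unary = dnl_forward(x, w_q, w_k, w_v, w_out)
print(out.shape)  # (8, 6)
```

Zero-initializing the output projection (as `init_weights` does for `conv_out` with `zeros_init=True` in the first implementation) makes the block start as an identity mapping.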

    import torch
    import torch.nn as nn
    from mmcv.cnn import constant_init, normal_init
    
    from ..utils import ConvModule
    from mmdet.ops import ContextBlock
    
    from torch.nn.parameter import Parameter
    
    class NonLocal2D(nn.Module):
        """Non-local module.
        See https://arxiv.org/abs/1711.07971 for details.
        Args:
            in_channels (int): Channels of the input feature map.
            reduction (int): Channel reduction ratio.
            use_scale (bool): Whether to scale pairwise_weight by 1/inter_channels.
            conv_cfg (dict): The config dict for convolution layers.
                (only applicable to conv_out)
            norm_cfg (dict): The config dict for normalization layers.
                (only applicable to conv_out)
            mode (str): Options are `embedded_gaussian` and `dot_product`.
        """
    
        def __init__(self,
                     in_channels,
                     reduction=2,
                     use_scale=True,
                     conv_cfg=None,
                     norm_cfg=None,
                     mode='embedded_gaussian',
                     whiten_type=None,
                     temp=1.0,
                     downsample=False,
                     fixbug=False,
                     learn_t=False,
                     gcb=None):
            super(NonLocal2D, self).__init__()
            self.in_channels = in_channels
            self.reduction = reduction
            self.use_scale = use_scale
            self.inter_channels = in_channels // reduction
            self.mode = mode
            assert mode in ['embedded_gaussian', 'dot_product', 'gaussian']
            if mode == 'gaussian':
                self.with_embedded = False
            else:
                self.with_embedded = True
            self.whiten_type = whiten_type
            assert whiten_type in [None, 'channel', 'bn-like']  # TODO: support more
            self.learn_t = learn_t
            if self.learn_t:
                self.temp = Parameter(torch.Tensor(1))
                self.temp.data.fill_(temp)
            else:
                self.temp = temp
            if downsample:
                self.downsample = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
            else:
                self.downsample = None
            self.fixbug=fixbug
    
            assert gcb is None or isinstance(gcb, dict)
            self.gcb = gcb
            if gcb is not None:
                self.gc_block = ContextBlock(inplanes=in_channels, **gcb)
            else:
                self.gc_block = None
    
            # g, theta, phi are actually `nn.Conv2d`. Here we use ConvModule for
            # potential usage.
            self.g = ConvModule(
                self.in_channels,
                self.inter_channels,
                kernel_size=1,
                activation=None)
            if self.with_embedded:
                self.theta = ConvModule(
                    self.in_channels,
                    self.inter_channels,
                    kernel_size=1,
                    activation=None)
                self.phi = ConvModule(
                    self.in_channels,
                    self.inter_channels,
                    kernel_size=1,
                    activation=None)
            self.conv_out = ConvModule(
                self.inter_channels,
                self.in_channels,
                kernel_size=1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                activation=None)
    
            self.init_weights()
    
        def init_weights(self, std=0.01, zeros_init=True):
            transform_list = [self.g]
            if self.with_embedded:
                transform_list.extend([self.theta, self.phi])
            for m in transform_list:
                normal_init(m.conv, std=std)
            if zeros_init:
                constant_init(self.conv_out.conv, 0)
            else:
                normal_init(self.conv_out.conv, std=std)
    
        def embedded_gaussian(self, theta_x, phi_x):
            # pairwise_weight: [N, HxW, HxW]
            pairwise_weight = torch.matmul(theta_x, phi_x)
            if self.use_scale:
                # theta_x.shape[-1] is `self.inter_channels`
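                if self.fixbug:
                    # intended scaling: divide by sqrt(C)
                    pairwise_weight /= theta_x.shape[-1]**0.5
                else:
                    # legacy behaviour kept for reproducibility: dividing by
                    # C**-0.5 actually multiplies by sqrt(C)
                    pairwise_weight /= theta_x.shape[-1]**-0.5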
            if self.learn_t:
                pairwise_weight = pairwise_weight * nn.functional.softplus(self.temp) # stable training
            else:
                pairwise_weight = pairwise_weight / self.temp
            pairwise_weight = pairwise_weight.softmax(dim=-1)
            return pairwise_weight
    
        def gaussian(self, theta_x, phi_x):
            return self.embedded_gaussian(theta_x, phi_x)
    
        def dot_product(self, theta_x, phi_x):
            # pairwise_weight: [N, HxW, HxW]
            pairwise_weight = torch.matmul(theta_x, phi_x)
            pairwise_weight /= pairwise_weight.shape[-1]
            return pairwise_weight
    
        def forward(self, x):
            n, _, h, w = x.shape
            if self.downsample:
                down_x = self.downsample(x)
            else:
                down_x = x
    
            # g_x: [N, H'xW', C], VALUE?
            g_x = self.g(down_x).view(n, self.inter_channels, -1)
            g_x = g_x.permute(0, 2, 1)
    
            # theta_x: [N, HxW, C], QUERY?
            if self.with_embedded:
                theta_x = self.theta(x).view(n, self.inter_channels, -1)
                theta_x = theta_x.permute(0, 2, 1)
            else:
                theta_x = x.view(n, self.in_channels, -1)
                theta_x = theta_x.permute(0, 2, 1)
    
            # phi_x: [N, C, H'xW'], KEY?
            if self.with_embedded:
                phi_x = self.phi(down_x).view(n, self.inter_channels, -1)
            else:
                phi_x = x.view(n, self.in_channels, -1)
    
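            # whiten: subtract the mean over positions
            # theta_x is [N, HxW, C] (positions on dim 1), phi_x is [N, C, H'xW']
            # (positions on dim 2). Out-of-place ops avoid modifying `x` when
            # theta_x / phi_x are views of it (mode == 'gaussian').
            if self.whiten_type == "channel":
                theta_x_mean = theta_x.mean(1).unsqueeze(1)
                phi_x_mean = phi_x.mean(2).unsqueeze(2)
                theta_x = theta_x - theta_x_mean
                phi_x = phi_x - phi_x_mean
            elif self.whiten_type == 'bn-like':
                theta_x_mean = theta_x.mean(1).mean(0).unsqueeze(0).unsqueeze(1)
                phi_x_mean = phi_x.mean(2).mean(0).unsqueeze(0).unsqueeze(2)
                theta_x = theta_x - theta_x_mean
                phi_x = phi_x - phi_x_mean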
    
            pairwise_func = getattr(self, self.mode)
            # pairwise_weight: [N, HxW, H'xW']
            pairwise_weight = pairwise_func(theta_x, phi_x)
    
            # y: [N, HxW, C]
            y = torch.matmul(pairwise_weight, g_x)
            # y: [N, C, H, W]
            y = y.permute(0, 2, 1).reshape(n, self.inter_channels, h, w)
    
    
            # gc block
            if self.gcb:
                output = self.gc_block(x) + self.conv_out(y)
            else:
                output = x + self.conv_out(y)
    
            return output
    
    import torch
    import torch.nn.functional as F
    #from libs import InPlaceABN, InPlaceABNSync
    from torch import nn
    from torch.nn import init
    import math
    
    
    class _NonLocalNd_bn(nn.Module):
    
        def __init__(self, dim, inplanes, planes, downsample, use_gn, lr_mult, use_out, out_bn, whiten_type, temperature,
                     with_gc, with_unary):
            assert dim in [1, 2, 3], "dim {} is not supported yet".format(dim)
            # assert whiten_type in ['channel', 'spatial']
            if dim == 3:
                conv_nd = nn.Conv3d
                if downsample:
                    max_pool = nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2))
                else:
                    max_pool = None
                bn_nd = nn.BatchNorm3d
            elif dim == 2:
                conv_nd = nn.Conv2d
                if downsample:
                    max_pool = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
                else:
                    max_pool = None
                bn_nd = nn.BatchNorm2d
            else:
                conv_nd = nn.Conv1d
                if downsample:
                    max_pool = nn.MaxPool1d(kernel_size=2, stride=2)
                else:
                    max_pool = None
                bn_nd = nn.BatchNorm1d
    
            super(_NonLocalNd_bn, self).__init__()
            self.conv_query = conv_nd(inplanes, planes, kernel_size=1)
            self.conv_key = conv_nd(inplanes, planes, kernel_size=1)
            if use_out:
                self.conv_value = conv_nd(inplanes, planes, kernel_size=1)
                self.conv_out = conv_nd(planes, inplanes, kernel_size=1, bias=False)
            else:
                self.conv_value = conv_nd(inplanes, inplanes, kernel_size=1, bias=False)
                self.conv_out = None
            if out_bn:
                self.out_bn = bn_nd(inplanes)  # match the block's dimensionality (1d/2d/3d)
            else:
                self.out_bn = None
            if with_gc:
                self.conv_mask = conv_nd(inplanes, 1, kernel_size=1)
            if 'bn_affine' in whiten_type:
                self.key_bn_affine = nn.BatchNorm1d(planes)
                self.query_bn_affine = nn.BatchNorm1d(planes)
            if 'bn' in whiten_type:
                self.key_bn = nn.BatchNorm1d(planes, affine=False)
                self.query_bn = nn.BatchNorm1d(planes, affine=False)
            self.softmax = nn.Softmax(dim=2)
            self.downsample = max_pool
            # self.norm = nn.GroupNorm(num_groups=32, num_channels=inplanes) if use_gn else InPlaceABNSync(num_features=inplanes)
            self.gamma = nn.Parameter(torch.zeros(1))
            self.scale = math.sqrt(planes)
            self.whiten_type = whiten_type
            self.temperature = temperature
            self.with_gc = with_gc
            self.with_unary = with_unary
    
            self.reset_parameters()
            self.reset_lr_mult(lr_mult)
    
        def reset_parameters(self):
    
            for m in self.modules():
                if isinstance(m, nn.Conv3d) or isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d):
                    init.normal_(m.weight, 0, 0.01)
                    if m.bias is not None:
                        init.zeros_(m.bias)
                    m.inited = True
            # init.constant_(self.norm.weight, 0)
            # init.constant_(self.norm.bias, 0)
            # self.norm.inited = True
    
        def reset_lr_mult(self, lr_mult):
            if lr_mult is not None:
                for m in self.modules():
                    m.lr_mult = lr_mult
            else:
                print('not change lr_mult')
    
        def forward(self, x):
            # [N, C, T, H, W]
            residual = x
            # [N, C, T, H', W']
            if self.downsample is not None:
                input_x = self.downsample(x)
            else:
                input_x = x
    
            # [N, C', T, H, W]
            query = self.conv_query(x)
            # [N, C', T, H', W']
            key = self.conv_key(input_x)
            value = self.conv_value(input_x)
    
            # [N, C', H x W]
            query = query.view(query.size(0), query.size(1), -1)
            # [N, C', H' x W']
            key = key.view(key.size(0), key.size(1), -1)
            value = value.view(value.size(0), value.size(1), -1)
    
            if 'channel' in self.whiten_type:
                key_mean = key.mean(2).unsqueeze(2)
                query_mean = query.mean(2).unsqueeze(2)
                key -= key_mean
                query -= query_mean
            if 'spatial' in self.whiten_type:
                key_mean = key.mean(1).unsqueeze(1)
                query_mean = query.mean(1).unsqueeze(1)
                key -= key_mean
                query -= query_mean
            if 'bn_affine' in self.whiten_type:
                key = self.key_bn_affine(key)
                query = self.query_bn_affine(query)
            if 'bn' in self.whiten_type:
                key = self.key_bn(key)
                query = self.query_bn(query)
            if 'ln_nostd' in self.whiten_type:
                key_mean = key.mean(1).mean(1).view(key.size(0), 1, 1)
                query_mean = query.mean(1).mean(1).view(query.size(0), 1, 1)
                key -= key_mean
                query -= query_mean
    
            # [N, T x H x W, T x H' x W']
            sim_map = torch.bmm(query.transpose(1, 2), key)
            sim_map = sim_map / self.scale
            sim_map = sim_map / self.temperature
            sim_map = self.softmax(sim_map)
    
            # [N, T x H x W, C']
            out_sim = torch.bmm(sim_map, value.transpose(1, 2))
            # [N, C', T x H x W]
            out_sim = out_sim.transpose(1, 2)
            # [N, C', T,  H, W]
            out_sim = out_sim.view(out_sim.size(0), out_sim.size(1), *x.size()[2:])
            # if self.norm is not None:
            #     out = self.norm(out)
            out_sim = self.gamma * out_sim
            
            if self.with_unary:
                if query_mean.shape[1] ==1:
                    query_mean = query_mean.expand(-1, key.shape[1], -1)
                unary = torch.bmm(query_mean.transpose(1,2),key)
                unary = self.softmax(unary)
                out_unary = torch.bmm(value, unary.permute(0,2,1)).unsqueeze(-1)
                out_sim = out_sim + out_unary
    
            # out = residual + out_sim
    
            if self.with_gc:
                # [N, 1, H', W']
                mask = self.conv_mask(input_x)
                # [N, 1, H'x W']
                mask = mask.view(mask.size(0), mask.size(1), -1)
                mask = self.softmax(mask)
                # [N, C', 1, 1]
                out_gc = torch.bmm(value, mask.permute(0, 2, 1)).unsqueeze(-1)
                out_sim = out_sim + out_gc
    
            # [N, C, T,  H, W]
            if self.conv_out is not None:
                out_sim = self.conv_out(out_sim)
            if self.out_bn:
                out_sim = self.out_bn(out_sim)
    
            out = out_sim + residual
    
            return out
    
    
    class NonLocal2d_bn(_NonLocalNd_bn):
    
        def __init__(self, inplanes, planes, downsample=True, use_gn=False, lr_mult=None, use_out=False, out_bn=False,
                     whiten_type=['channel'], temperature=1.0, with_gc=False, with_unary=False):
            super(NonLocal2d_bn, self).__init__(dim=2, inplanes=inplanes, planes=planes, downsample=downsample,
                                                use_gn=use_gn, lr_mult=lr_mult, use_out=use_out, out_bn=out_bn,
                                                whiten_type=whiten_type, temperature=temperature, with_gc=with_gc, with_unary=with_unary)
    
  • Original post: https://www.cnblogs.com/wjy-lulu/p/13680844.html