zoukankan      html  css  js  c++  java
  • 光流 | flownet | CVPR2015 | 论文+pytorch代码

    • 文章转自微信公众号「机器学习炼丹术」
    • 作者:炼丹兄(已授权)
    • 作者联系方式:微信cyx645016617(欢迎交流 共同进步)
    • 论文名称:“FlowNet: Learning Optical Flow with Convolutional Networks”
    • 论文链接:http://xxx.itp.ac.cn/abs/1504.06852

    image.png

    0 综述

    论文的主要贡献在我看来有两个:

    • 提出了flownet结构,也就是flownet-v1(现在已经更新到flownet-v2版本),flownet-v1中包含两个版本,一个是flownet-v1S(simple),另一个是flownet-v1C(correlation)。
    • 提出了著名的Flying chairs数据集,飞翔的椅子哈哈,做光流的应该都知道这个有趣的数据集。

    1 flownetsimple

    1.1 特征提取模块

    已知卷积神经网络在具有足够的标记数据的情况下非常擅长学习输入输出关系。因此,采用端到端的学习方法来预测光流:

    给定一个由图像对和光流组成的数据集,我们训练网络以直接从图像中预测x-y的光流场。但是,为此目的,好的架构是什么?

    一个简单的选择是将两个输入图像堆叠在一起,并通过一个相当通用的网络将其输入,从而使网络可以自行决定如何处理图像对以提取运动信息。这种仅包含卷积层的架构为“FlowNetSimple”:

    image.png

    炼丹兄简单讲网络结构:

    • 输入图片是三通道的,把两张图片concat起来,变成6通道的图片;
    • 之后就是常规的:卷积卷积卷积,中间会混杂stride=2的下采样;
    • 然后有几个卷积层输出的特征图会直连到refinement部分(图中右边绿色沙漏型图块),这些特征融合,让图片恢复较大的尺寸;
    • 原始模型的输入图片是384x512,但是最后输出的光流场大小为136x320,这一点需要注意一下。

    1.2 refinement

    我们来看refinement部分,其实这个部分跟Unet也有些类似,但是又有独特的光流模型的特性。

    image.png

    • 可以看到,基本上每一块的特征图,都包含三个部分:
      • 从前一个小尺寸的特征图deconv得到的特征;
      • 从前一个小尺寸的特征图转换成小尺寸的光流场然后deconv得到的特征;
      • 在特征提取过程中,与之尺寸相匹配的特征;
    • 上面的三个特征concat之后,就会变成下一个尺寸的输入特征块,不断循环,让特征的尺寸不断放大;

    1.3 pytorch

    lass FlowNetS(nn.Module):
        expansion = 1
    
        def __init__(self,batchNorm=True):
            super(FlowNetS,self).__init__()
    
            self.batchNorm = batchNorm
            self.conv1   = conv(self.batchNorm,   6,   64, kernel_size=7, stride=2)
            self.conv2   = conv(self.batchNorm,  64,  128, kernel_size=5, stride=2)
            self.conv3   = conv(self.batchNorm, 128,  256, kernel_size=5, stride=2)
            self.conv3_1 = conv(self.batchNorm, 256,  256)
            self.conv4   = conv(self.batchNorm, 256,  512, stride=2)
            self.conv4_1 = conv(self.batchNorm, 512,  512)
            self.conv5   = conv(self.batchNorm, 512,  512, stride=2)
            self.conv5_1 = conv(self.batchNorm, 512,  512)
            self.conv6   = conv(self.batchNorm, 512, 1024, stride=2)
            self.conv6_1 = conv(self.batchNorm,1024, 1024)
    
            self.deconv5 = deconv(1024,512)
            self.deconv4 = deconv(1026,256)
            self.deconv3 = deconv(770,128)
            self.deconv2 = deconv(386,64)
    
            self.predict_flow6 = predict_flow(1024)
            self.predict_flow5 = predict_flow(1026)
            self.predict_flow4 = predict_flow(770)
            self.predict_flow3 = predict_flow(386)
            self.predict_flow2 = predict_flow(194)
    
            self.upsampled_flow6_to_5 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)
            self.upsampled_flow5_to_4 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)
            self.upsampled_flow4_to_3 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)
            self.upsampled_flow3_to_2 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)
    
            for m in self.modules():
                if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
                    kaiming_normal_(m.weight, 0.1)
                    if m.bias is not None:
                        constant_(m.bias, 0)
                elif isinstance(m, nn.BatchNorm2d):
                    constant_(m.weight, 1)
                    constant_(m.bias, 0)
    
        def forward(self, x):
            out_conv2 = self.conv2(self.conv1(x))
            out_conv3 = self.conv3_1(self.conv3(out_conv2))
            out_conv4 = self.conv4_1(self.conv4(out_conv3))
            out_conv5 = self.conv5_1(self.conv5(out_conv4))
            out_conv6 = self.conv6_1(self.conv6(out_conv5))
    
            flow6       = self.predict_flow6(out_conv6)
            flow6_up    = crop_like(self.upsampled_flow6_to_5(flow6), out_conv5)
            out_deconv5 = crop_like(self.deconv5(out_conv6), out_conv5)
    
            concat5 = torch.cat((out_conv5,out_deconv5,flow6_up),1)
            flow5       = self.predict_flow5(concat5)
            flow5_up    = crop_like(self.upsampled_flow5_to_4(flow5), out_conv4)
            out_deconv4 = crop_like(self.deconv4(concat5), out_conv4)
    
            concat4 = torch.cat((out_conv4,out_deconv4,flow5_up),1)
            flow4       = self.predict_flow4(concat4)
            flow4_up    = crop_like(self.upsampled_flow4_to_3(flow4), out_conv3)
            out_deconv3 = crop_like(self.deconv3(concat4), out_conv3)
    
            concat3 = torch.cat((out_conv3,out_deconv3,flow4_up),1)
            flow3       = self.predict_flow3(concat3)
            flow3_up    = crop_like(self.upsampled_flow3_to_2(flow3), out_conv2)
            out_deconv2 = crop_like(self.deconv2(concat3), out_conv2)
    
            concat2 = torch.cat((out_conv2,out_deconv2,flow3_up),1)
            flow2 = self.predict_flow2(concat2)
    
            if self.training:
                return flow2,flow3,flow4,flow5,flow6
            else:
                return flow2
    
        def weight_parameters(self):
            return [param for name, param in self.named_parameters() if 'weight' in name]
    
        def bias_parameters(self):
            return [param for name, param in self.named_parameters() if 'bias' in name]
    
    • 代码中,deconv都是反卷积层,upsampled_flow也是反卷积层,其他的都是卷积层+bn+LeakyReLU的组合;
    • 代码中的过程和之前讲述的过程完全符合,这一点我在复现的过程中已经核实过了。
    • 最后发现,在训练过程输入的是flow2,flow3等5个尺寸不同的光流场,这自然是为了计算损失,在论文中虽然没有提到损失函数,但是从代码中可以看到使用的是多尺度的损失,(类似于辅助损失的概念)。损失函数在后面会讲解。

    2 损失函数

    import torch
    import torch.nn.functional as F
    
    
    def EPE(input_flow, target_flow, sparse=False, mean=True):
        EPE_map = torch.norm(target_flow-input_flow,2,1)
        batch_size = EPE_map.size(0)
        if sparse:
            # invalid flow is defined with both flow coordinates to be exactly 0
            mask = (target_flow[:,0] == 0) & (target_flow[:,1] == 0)
    
            EPE_map = EPE_map[~mask]
        if mean:
            return EPE_map.mean()
        else:
            return EPE_map.sum()/batch_size
    
    • 先看这一部分,计算损失的关键就是输入光流和输出光流的差值的l2范数;
    • sparse应该是输入的光流场是稀疏光流还是密集光流,我的实验中都是密集光流,所以我就无视sparse内的操作了;
    def multiscaleEPE(network_output, target_flow, weights=None, sparse=False):
        def one_scale(output, target, sparse):
    
            b, _, h, w = output.size()
    
            if sparse:
                target_scaled = sparse_max_pool(target, (h, w))
            else:
                target_scaled = F.interpolate(target, (h, w), mode='area')
            return EPE(output, target_scaled, sparse, mean=False)
    
        if type(network_output) not in [tuple, list]:
            network_output = [network_output]
        if weights is None:
            weights = [0.005, 0.01, 0.02, 0.08, 0.32]  # as in original article
        assert(len(weights) == len(network_output))
    
        loss = 0
        for output, weight in zip(network_output, weights):
            loss += weight * one_scale(output, target_flow, sparse)
        return loss
    
    • 这个部分就是全部的损失了,因为之前模型部分的代码输出的为(flow2,flow3,flow4,flow5,flow6)这样的一个tuple形式;
    • 在函数里,one_scale的方法是把ground truth的光流数据,用intepolate降采样成和flow尺寸大小相同的grount truth,然后再放到EPE中进行计算;
    • 不同尺寸的特征图计算出来的损失,要乘上一个权重,代码中的权重是flownet论文中的原始权重值。

    4 correlation

    image.png
    这里和之前的simple版本的区别,在于:先对图片做了相同的特征处理,类似于孪生网络,然后对于提取的两个特征图,做论文中提出的叫做correlation处理,融合成一个特征图,然后再做类似于simple版本的后续处理

    这里直接看模型代码:

    class FlowNetC(nn.Module):
        expansion = 1
    
        def __init__(self,batchNorm=True):
            super(FlowNetC,self).__init__()
    
            self.batchNorm = batchNorm
            self.conv1      = conv(self.batchNorm,   3,   64, kernel_size=7, stride=2)
            self.conv2      = conv(self.batchNorm,  64,  128, kernel_size=5, stride=2)
            self.conv3      = conv(self.batchNorm, 128,  256, kernel_size=5, stride=2)
            self.conv_redir = conv(self.batchNorm, 256,   32, kernel_size=1, stride=1)
    
            self.conv3_1 = conv(self.batchNorm, 473,  256)
            self.conv4   = conv(self.batchNorm, 256,  512, stride=2)
            self.conv4_1 = conv(self.batchNorm, 512,  512)
            self.conv5   = conv(self.batchNorm, 512,  512, stride=2)
            self.conv5_1 = conv(self.batchNorm, 512,  512)
            self.conv6   = conv(self.batchNorm, 512, 1024, stride=2)
            self.conv6_1 = conv(self.batchNorm,1024, 1024)
    
            self.deconv5 = deconv(1024,512)
            self.deconv4 = deconv(1026,256)
            self.deconv3 = deconv(770,128)
            self.deconv2 = deconv(386,64)
    
            self.predict_flow6 = predict_flow(1024)
            self.predict_flow5 = predict_flow(1026)
            self.predict_flow4 = predict_flow(770)
            self.predict_flow3 = predict_flow(386)
            self.predict_flow2 = predict_flow(194)
    
            self.upsampled_flow6_to_5 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)
            self.upsampled_flow5_to_4 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)
            self.upsampled_flow4_to_3 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)
            self.upsampled_flow3_to_2 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)
    
            for m in self.modules():
                if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
                    kaiming_normal_(m.weight, 0.1)
                    if m.bias is not None:
                        constant_(m.bias, 0)
                elif isinstance(m, nn.BatchNorm2d):
                    constant_(m.weight, 1)
                    constant_(m.bias, 0)
    
        def forward(self, x):
            x1 = x[:,:3]
            x2 = x[:,3:]
    
            out_conv1a = self.conv1(x1)
            out_conv2a = self.conv2(out_conv1a)
            out_conv3a = self.conv3(out_conv2a)
    
            out_conv1b = self.conv1(x2)
            out_conv2b = self.conv2(out_conv1b)
            out_conv3b = self.conv3(out_conv2b)
    
            out_conv_redir = self.conv_redir(out_conv3a)
            out_correlation = correlate(out_conv3a,out_conv3b)
    
            in_conv3_1 = torch.cat([out_conv_redir, out_correlation], dim=1)
    
            out_conv3 = self.conv3_1(in_conv3_1)
            out_conv4 = self.conv4_1(self.conv4(out_conv3))
            out_conv5 = self.conv5_1(self.conv5(out_conv4))
            out_conv6 = self.conv6_1(self.conv6(out_conv5))
    
            flow6       = self.predict_flow6(out_conv6)
            flow6_up    = crop_like(self.upsampled_flow6_to_5(flow6), out_conv5)
            out_deconv5 = crop_like(self.deconv5(out_conv6), out_conv5)
    
            concat5 = torch.cat((out_conv5,out_deconv5,flow6_up),1)
            flow5       = self.predict_flow5(concat5)
            flow5_up    = crop_like(self.upsampled_flow5_to_4(flow5), out_conv4)
            out_deconv4 = crop_like(self.deconv4(concat5), out_conv4)
    
            concat4 = torch.cat((out_conv4,out_deconv4,flow5_up),1)
            flow4       = self.predict_flow4(concat4)
            flow4_up    = crop_like(self.upsampled_flow4_to_3(flow4), out_conv3)
            out_deconv3 = crop_like(self.deconv3(concat4), out_conv3)
    
            concat3 = torch.cat((out_conv3,out_deconv3,flow4_up),1)
            flow3       = self.predict_flow3(concat3)
            flow3_up    = crop_like(self.upsampled_flow3_to_2(flow3), out_conv2a)
            out_deconv2 = crop_like(self.deconv2(concat3), out_conv2a)
    
            concat2 = torch.cat((out_conv2a,out_deconv2,flow3_up),1)
            flow2 = self.predict_flow2(concat2)
    
            if self.training:
                return flow2,flow3,flow4,flow5,flow6
            else:
                return flow2
    
        def weight_parameters(self):
            return [param for name, param in self.named_parameters() if 'weight' in name]
    
        def bias_parameters(self):
            return [param for name, param in self.named_parameters() if 'bias' in name]
    
    

    里面的关键在这个部分:

    out_conv_redir = self.conv_redir(out_conv3a)
    out_correlation = correlate(out_conv3a,out_conv3b)
    in_conv3_1 = torch.cat([out_conv_redir, out_correlation], dim=1)
    
    • 这里面的self.conv_redir是卷积层+bn+leakyrelu这一套;
    • 但是关于correlate这个部分中,github作者引用了一个from spatial_correlation_sampler import spatial_correlation_sample,但是这个库并没有在代码中提供,所以关于这个版本的flownet,我也就此作罢。我猜测这个模块是作者引用别人的代码,应该在github主页有说明,但是我这里上github太卡了,回头有空再补充这个知识点把。(不过一般也没有什么人看文章哈哈,没人问我的话,那我就忽视这个坑了2333)

    3 总结

    • flownet在有些情况下确实很好用,训练收敛的还挺快。
    人不可傲慢。
  • 相关阅读:
    div在IOS系统和安卓系统位置不同
    js操作样式
    Css设置文字旋转
    textarea高度自适应
    html引入html页面
    举例说明$POST 、$HTTP_RAW_POST_DATA、php://input三者之间的区别
    PHP获取POST的几种方法
    PHP以xml形式获取POST数据
    使用Composer安装Symfony
    php如何以post形式发送xm并返回xmll数据
  • 原文地址:https://www.cnblogs.com/PythonLearner/p/14663391.html
Copyright © 2011-2022 走看看