zoukankan      html  css  js  c++  java
  • [论文理解] Why do deep convolutional networks generalize so poorly to small image transformations?

    Why do deep convolutional networks generalize so poorly to small image transformations?

    Intro

    CNN的设计初衷是为了使得模型具有微小平移、旋转不变性,而实际上本文通过实验验证了现在比较流行的神经网络都已经丧失了这样的能力,甚至图像只水平移动一个像素,预测的结果都将会发生很大的变化。之所以如此,作者认为CNN的下采样背离了隆奎斯特采样定理,就连augmentation也并不能缓解微小变化不变性的丧失。

    Ignoring the Sampling Theorem

    1. 随机crop一张图然后resize,平移一个像素
    2. 原图rescale黑色填充背景,平移一个像素
    3. 原图rescale下方插值,平移一个像素
    4. 原图rescale,rescale相差一个像素值

    结果如图,对人而言是看不出来什么区别的,但是网络的预测结果却明显不同。

    仅仅这么一个微小变化,使得网络的输出发生这么大的变化,这显然是很不合理的,因此作者认为有下面的解释。

    作者认为之所以CNN丧失了平移不变性是因为现在CNN的设计忽略了采样定理,也就是奈奎斯特采样定理。作者认为,CNN中的下采样操作,如stride 2 conv和pooling,由于采样频率没满足大于等于两倍的最高频,所以会导致微小形变不变性的丧失。

    如何理解?

    上图我做了三个实验,分别是原图、pooling后的图、conv2d(stride=4)后的图,下面是对应的频谱图。可以看出来,pooling、conv下采样层都是保留高频部分,丢弃低频部分。

    可见,pooling层和conv层所做的下采样都是保留高频,丢弃低频成分。因此信号的频率都是高频,而采样频率是和stride相关的,可以理解为每隔多少像素采样一次,因此如果stride太大,采样频率就会过小,这样就不满足采样定理了。

    import cv2
    import numpy as np
    import matplotlib.pyplot as plt
    import skimage.measure
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    class FFT(object):
        def __init__(self):
            pass
        def __call__(self,x):
            f = np.fft.fft2(x)
            fshift = np.fft.fftshift(f)
            s1 = np.log(np.abs(fshift))
            return s1
        @staticmethod
        def pooling(feature_map,kernel_size= 2):
            inpt = torch.tensor(feature_map[np.newaxis,np.newaxis,:,:],dtype = torch.float32)
            
            out = F.max_pool2d(inpt,kernel_size = kernel_size)
            
            return out.numpy()[0,0,:,:]
            #return skimage.measure.block_reduce(feature_map, (kernel_size,kernel_size), np.max)
        @staticmethod
        def conv2d(feature_map,kernel_size= 2,stride = 1):
            inpt = torch.tensor(feature_map[np.newaxis,np.newaxis,:,:],dtype = torch.float32)
            kernel = torch.ones((1,1,kernel_size,kernel_size)) / kernel_size
            out = F.conv2d(inpt,weight = kernel,stride = stride)
            return out.numpy()[0,0,:,:]
            
    if __name__ == "__main__":    
        fft = FFT()
        img = cv2.imread('/home/xueaoru/图片/test002.jpg', 0)
        img2 = fft.pooling(img,4)
        img3 = fft.conv2d(img,kernel_size = 5,stride = 4)
        s1 = fft(img)
        s2 = fft(img2)
        s3 = fft(img3)
        plt.subplot(231)
        plt.imshow(img, 'gray')
        plt.title('original')
    
        plt.subplot(232)
        plt.imshow(img2, 'gray')
        plt.title('original2')
    
    
        plt.subplot(233)
        plt.imshow(img3, 'gray')
        plt.title('original3')
    
        plt.subplot(234)
        plt.imshow(s1,'gray')
        plt.title('center1')
    
    
        plt.subplot(235)
        plt.imshow(s2,'gray')
        plt.title('center2')
    
    
        plt.subplot(236)
        plt.imshow(s3,'gray')
        plt.title('center3')
    
        plt.show()
    

    此外,作者还说明了一件事:

    网络越深,微小形变不变性丧失的越严重。上图可以看出vgg16比较浅,还看不出什么变化,但是resnet50已经丧失很多了,densenet201也是一样。

    Why don't modern CNNs learn to be invariant from data?

    我们训练之前一般会做很多image augmentation的,但是呢,为什么image augmentation也没能缓解平移不变性的丧失呢?

    一个简单的回答就是网络学习到一种简单的判别函数,使得网络对训练集的图片的变换具有微小形变不变性,而对于测试集或者说没见过的数据的微小形变不具有不变性。这称为dataset bias。

    上图是imagenet上训练集里狗的眼睛距离的统计结果,也就是说,一般情况下狗的眼睛肯定在这个range内,网络可以work,然是测试数据的眼睛距离如果不在这个range内的话,网络也许就不work了,因此要求网络学习到这个range之外的数据是很难的,虽然image augmentation有所帮助,但其实很难泛化到所有情况,特别是测试集和训练集不是同一分布的时候。

    Possible Solutions

    1. 将下采样替换为:stride 1 maxpooling + stride 2 conv.也就是采样之前先blur,但是对于大数据效果甚微。
    2. data augmentation:多增加额外的augmentation。
    3. 减少二次采样操作:二次采样会导致可能不满足采样定律,所以减少二次采样可以保持不变性。
  • 相关阅读:
    Percona 工具包 pt-online-schema-change 简介
    MySQL 中NULL和空值的区别
    MySQL二进制日志文件过期天数设置说明
    MySQL大小写敏感说明
    SpringBoot 配置Druid:不显示SQL监控 —(*) property for user to setup
    IDEA 启用/禁用 Run Dashboard
    java.lang.IllegalAccessException: Class XX can not access a member of class XXX with modifiers "private static"
    Swagger2常用注解说明
    更改IDEA默认使用JDK1.5编译项目
    Spring Boot : Swagger 2
  • 原文地址:https://www.cnblogs.com/aoru45/p/12222788.html
Copyright © 2011-2022 走看看