zoukankan      html  css  js  c++  java
  • 『论文笔记』Swin Transformer

    https://zhuanlan.zhihu.com/p/361366090

    目前transform的两个非常严峻的问题

    1. 受限于图像的矩阵性质,一个能表达信息的图片往往至少需要几百个像素点,而建模这种几百个长序列的数据恰恰是Transformer的天生缺陷;
    2. 目前的基于Transformer框架更多的是用来进行图像分类,对实例分割这种密集预测的场景Transformer并不擅长解决。

    在Swin Transformer之前的ViT和iGPT,它们都使用了小尺寸的图像作为输入,这种直接resize的策略无疑会损失很多信息。与它们不同的是,Swin Transformer的输入是图像的原始尺寸另外Swin Transformer使用的是CNN中最常用的层次的网络结构,在CNN中一个特别重要的一点是随着网络层次的加深,节点的感受野也在不断扩大,这个特征在Swin Transformer中也是满足的。Swin Transformer的这种层次结构,也赋予了它可以像FPN,U-Net等结构实现可以进行分割或者检测的任务。


    图1:Swin Transformer和ViT的对比

     

     

    图2:Swin-T的网络结构 

    Patch Partition/Patch Merging

     在图2中,输入图像之后是一个Patch Partition,再之后是一个Linear Embedding层,这两个加在一起其实就是一个Patch Merging层(至少上面的源码中是这么实现的)。这一部分的源码如下:

    class PatchMerging(nn.Module):
        def __init__(self, in_channels, out_channels, downscaling_factor):
            super().__init__()
            self.downscaling_factor = downscaling_factor
            self.patch_merge = nn.Unfold(kernel_size=downscaling_factor, stride=downscaling_factor, padding=0)
            self.linear = nn.Linear(in_channels * downscaling_factor ** 2, out_channels)
    
        def forward(self, x):
            b, c, h, w = x.shape
            new_h, new_w = h // self.downscaling_factor, w // self.downscaling_factor
            x = self.patch_merge(x) # (1, 48, 3136)
            x = x.view(b, -1, new_h, new_w).permute(0, 2, 3, 1) # (1, 56, 56, 48)
            x = self.linear(x) # (1, 56, 56, 96)
            return x 

     

    Patch Merging的作用是对图像进行降采样,类似于CNN中Pooling层。Patch Merging是主要是通过nn.Unfold函数实现降采样的,nn.Unfold的功能是对图像进行滑窗,相当于卷积操作的第一步,因此它的参数包括窗口的大小和滑窗的步长。根据源码中给出的超参我们知道这一步降采样的比例是[公式] ,因此经过nn.Unfold之后会得到 [公式] 个长度为[公式] 的特征向量,其中 [公式] 是输入到这个stage的Feature Map的通道数,第一个stage的输入是RGB图像,因此通接着的viewpermute是将得到的向量序列还原到 [公式] 的二维矩阵,linear是将长度是 [公式] 的特征向量映射到out_channels的长度,因此stage-1的Patch Merging的输出向量维度是 [公式] ,对比源码的注释,这里省略了第一个batch为 [公式] 的维度。

    可以看出Patch Partition/Patch Merging起到的作用像是CNN中通过带有步长的滑窗来降低分辨率,再通过 [公式] 卷积来调整通道数。不同的是在CNN中最常使用的降采样的最大池化或者平均池化往往会丢弃一些信息,例如最大池化会丢弃一个窗口内的地响应值,而Patch Merging的策略并不会丢弃其它响应,但它的缺点是带来运算量的增加。在一些需要提升模型容量的场景中,我们其实可以考虑使用Patch Merging来替代CNN中的池化。

    作用就是更好的降采样,CNN中最常使用的降采样的最大池化或者平均池化往往会丢弃一些信息,例如最大池化会丢弃一个窗口内的地响应值,而Patch Merging的策略并不会丢弃其它响应,但它的缺点是带来运算量的增加。

    Swin Transformer的Stage

    Swin Transformer的一个stage便可以看做由Patch Merging和Swin Transformer Block组成,

    class StageModule(nn.Module):
        def __init__(self, in_channels, hidden_dimension, layers, downscaling_factor, num_heads, head_dim, window_size,
                     relative_pos_embedding):
            super().__init__()
            assert layers % 2 == 0, 'Stage layers need to be divisible by 2 for regular and shifted block.'
    
            self.patch_partition = PatchMerging(in_channels=in_channels, out_channels=hidden_dimension,
                                                downscaling_factor=downscaling_factor)
    
            self.layers = nn.ModuleList([])
            for _ in range(layers // 2):
                self.layers.append(nn.ModuleList([
                    SwinBlock(dim=hidden_dimension, heads=num_heads, head_dim=head_dim, mlp_dim=hidden_dimension * 4,
                              shifted=False, window_size=window_size, relative_pos_embedding=relative_pos_embedding),
                    SwinBlock(dim=hidden_dimension, heads=num_heads, head_dim=head_dim, mlp_dim=hidden_dimension * 4,
                              shifted=True, window_size=window_size, relative_pos_embedding=relative_pos_embedding),
                ]))
    
        def forward(self, x):
            x = self.patch_partition(x)
            for regular_block, shifted_block in self.layers:
                x = regular_block(x)
                x = shifted_block(x)
            return x.permute(0, 3, 1, 2)

    由窗口多头自注意层(window multi-head self-attention, W-MSA)和移位窗口多头自注意层(shifted-window multi-head self-attention, SW-MSA)组成,如图3所示。由于这个原因,Swin Transformer的层数要为2的整数倍,一层提供给W-MSA,一层提供给SW-MSA。

    图3:Swin Transformer Block的网络结构

    从图3中我们可以看出输入到该stage的特征 [公式] 先经过LN进行归一化,再经过W-MSA进行特征的学习,接着的是一个残差操作得到 [公式] 。接着是一个LN,一个MLP以及一个残差,得到这一层的输出特征 [公式] 。SW-MSA层的结构和W-MSA层类似,不同的是计算特征部分分别使用了SW-MSA和W-MSA,可以从上面的源码中看出它们除了shifted的这个bool值不同之外,其它的值是保持完全一致的。这一部分可以表示为式(2)。

    [公式]

    一个Swin Block的源码如下所示,和论文中图不同的是,LN层(PerNorm函数)从Self-Attention之前移到了Self-Attention之后。

    class Residual(nn.Module):
        def __init__(self, fn):
            super().__init__()
            self.fn = fn
    
        def forward(self, x, **kwargs):
            return self.fn(x, **kwargs) + x
    
    class PreNorm(nn.Module):
        def __init__(self, dim, fn):
            super().__init__()
            self.norm = nn.LayerNorm(dim)
            self.fn = fn
    
        def forward(self, x, **kwargs):
            return self.fn(self.norm(x), **kwargs)
    
    class SwinBlock(nn.Module):
        def __init__(self, dim, heads, head_dim, mlp_dim, shifted, window_size, relative_pos_embedding):
            super().__init__()
            self.attention_block = Residual(PreNorm(dim, WindowAttention(dim=dim, heads=heads, head_dim=head_dim, shifted=shifted, window_size=window_size, relative_pos_embedding=relative_pos_embedding)))
            self.mlp_block = Residual(PreNorm(dim, FeedForward(dim=dim, hidden_dim=mlp_dim)))
    
        def forward(self, x):
            x = self.attention_block(x)
            x = self.mlp_block(x)
            return x

    窗口多头自注意力(Window Multi-head Self Attention,W-MSA)

    顾名思义,就是个在窗口的尺寸上进行Self-Attention计算,与SW-MSA不同的是,它不会进行窗口移位,它们的源码如下。我们这里先忽略shiftedTrue的情况

    class WindowAttention(nn.Module):
        def __init__(self, dim, heads, head_dim, shifted, window_size, relative_pos_embedding):
            super().__init__()
            inner_dim = head_dim * heads
            self.heads = heads
            self.scale = head_dim ** -0.5
            self.window_size = window_size
            self.relative_pos_embedding = relative_pos_embedding # (13, 13)
            self.shifted = shifted
    
            if self.shifted:
                displacement = window_size // 2
                self.cyclic_shift = CyclicShift(-displacement)
                self.cyclic_back_shift = CyclicShift(displacement)
                self.upper_lower_mask = nn.Parameter(create_mask(window_size=window_size, displacement=displacement, upper_lower=True, left_right=False), requires_grad=False) # (49, 49)
                self.left_right_mask = nn.Parameter(create_mask(window_size=window_size, displacement=displacement,pper_lower=False, left_right=True), requires_grad=False) # (49, 49)
    
            self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
            if self.relative_pos_embedding:
                self.relative_indices = get_relative_distances(window_size) + window_size - 1
                self.pos_embedding = nn.Parameter(torch.randn(2 * window_size - 1, 2 * window_size - 1))
            else:
                self.pos_embedding = nn.Parameter(torch.randn(window_size ** 2, window_size ** 2))
    
            self.to_out = nn.Linear(inner_dim, dim)
    
        def forward(self, x):
            if self.shifted:
                x = self.cyclic_shift(x)
    
            b, n_h, n_w, _, h = *x.shape, self.heads # [1, 56, 56, _, 3]
            qkv = self.to_qkv(x).chunk(3, dim=-1) # [(1,56,56,96), (1,56,56,96), (1,56,56,96)]
            nw_h = n_h // self.window_size # 8
            nw_w = n_w // self.window_size # 8
            # 分成 h/M * w/M 个窗口
            q, k, v = map( lambda t: rearrange(t, 'b (nw_h w_h) (nw_w w_w) (h d) -> b h (nw_h nw_w) (w_h w_w) d', h=h, w_h=self.window_size, w_w=self.window_size), qkv)
            # q, k, v : (1, 3, 64, 49, 32)
            # 按窗口个数的self-attention
            dots = einsum('b h w i d, b h w j d -> b h w i j', q, k) * self.scale # (1,3,64,49,49)
    
            if self.relative_pos_embedding:
                dots += self.pos_embedding[self.relative_indices[:, :, 0], self.relative_indices[:, :, 1]]
            else:
                dots += self.pos_embedding
    
            if self.shifted:
                dots[:, :, -nw_w:] += self.upper_lower_mask
                dots[:, :, nw_w - 1::nw_w] += self.left_right_mask
    
            attn = dots.softmax(dim=-1) # (1,3,64,49,49)
            out = einsum('b h w i j, b h w j d -> b h w i d', attn, v)
            out = rearrange(out, 'b h (nw_h nw_w) (w_h w_w) d -> b (nw_h w_h) (nw_w w_w) (h d)', h=h, w_h=self.window_size, w_w=self.window_size, nw_h=nw_h, nw_w=nw_w) # (1, 56, 56, 96) # 窗口合并
            out = self.to_out(out)
            if self.shifted:
                out = self.cyclic_back_shift(out)
            return out

    forward函数中首先计算的是Transformer中介绍的 [公式] , [公式] , [公式] 三个特征。所以to_qkv()函数就是一个线性变换,这里使用了一个实现小技巧,即只使用了一个一层隐层节点数为inner_dim*3的线性变换,然后再使用chunk(3)操作将它们切开。因此qkv是一个长度为3的Tensor,每个Tensor的维度是 [公式] 。

    之后的map函数是实现W-MSA中的W最核心的代码,它是通过einopsrearrange实现的。einops是一个可读性非常高的实现常见矩阵操作的python包,例如矩阵转置,矩阵复制,矩阵reshape等操作。最终通过这个操作得到了3个独立的窗口的权值矩阵,它们的维度是 [公式] ,这4个值的意思分别是:

    • [公式] :多头自注意力的头的个数;
    • [公式] :窗口的个数,首先通过Patch Merging将图像的尺寸降到 [公式] ,因为窗口的大下为[公式] ,所以总共剩下 [公式] 个窗口;
    • [公式] :窗口的像素的个数;
    • [公式] :隐层节点的个数。

    Swin Transformer将计算区域控制在了以窗口为单位的策略极大减轻了网络的计算量,将复杂度降低到了图像尺寸的线性比例。传统的MSA和W-MSA的复杂度分别是:

    [公式]

    (3)式的计算忽略了softmax的占用的计算量,这里以​ [公式] 为例,它的具体构成如下:

    1. 代码中的to_qkv()函数,即用于生成 [公式] 三个特征向量:其中​ [公式] 。 [公式] ​的维度是 [公式] ​, [公式] ​的维度是 [公式] ​,那么这三项的复杂度是​ [公式] ;
    2. 计算 [公式] : [公式] 的维度均是 [公式] ,因此它的复杂度是 [公式] ;
    3. softmax之后乘 [公式] 得到 [公式] :因为 [公式] 的维度是 [公式] ,所以它的复杂度是​ [公式] ;
    4. [公式] 乘 [公式] 矩阵​得到最终输出,对应代码中的to_out()函数:它的复杂度是​ [公式] 。

    通过Transformer的计算公式(4),我们可以有更直观一点的理解,在Transformer一文中我们介绍过Self-Attention是通过点乘的方式得到Query矩阵和Key矩阵的相似度,即(4)式中的 [公式] 。然后再通过这个相似度匹配Value。因此这个相似度的计算时通过逐个元素进行点乘计算得到的。如果比较的范围是一个图像,那么计算的瓶颈就是整个图的逐像素比较,因此复杂度是 [公式] 。而W-MSA是在窗口内的逐像素比较,因此复杂度是 [公式] ,其中 [公式] 是W-MSA的窗口的大小。

    [公式]

    回到代码,接着的dots变量便是我们刚刚介绍的 [公式] 操作。接着是加入相对位置编码,我们放到最后介绍。接着的attn以及einsum便是完成了式(4)的整个流程。然后再次使用rearrange将维度再调整回 [公式] 。最后通过to_out将维度调整为超参设置的输出维度的值。

    这里我们介绍一下W-MSA的相对位置编码,首先这个位置编码是加在乘以完归一化尺度之后的dots变量上的,因此 [公式] 的计算方式变为式(5)。因为W-MSA是以窗口为单位进行特征匹配的,因此相对位置编码的范围也应该是以窗口为单位,它的具体实现见下面代码。相对位置编码的具体思想参考UniLMv2[8]。

    [公式]

    def get_relative_distances(window_size):
        indices = torch.tensor(np.array([[x, y] for x in range(window_size) for y in range(window_size)]))
        distances = indices[None, :, :] - indices[:, None, :]
        return distances

    单独的使用W-MSA得到的网络的建模能力是非常差的,因为它将每个窗口当做一个独立区域计算而忽略了窗口之间交互的必要性,基于这个动机,Swin Transformer提出了SW-MSA。

    SW-MSA

    SW-MSA的的位置是接在W-MSA层之后的,因此只要我们提供一种和W-MSA不同的窗口切分方式便可以实现跨窗口的通信。

    SW-MSA的实现方式如图4所示。我们上面说过,输入到Stage-1的图像尺寸是 [公式] 的(图4.(a)),那么W-MSA的窗口切分的结果如图4.(b)所示。

     SW-MSA的思想很简单,将图像各循环上移和循环左移半个窗口的大小,那么图4.(c)的蓝色和红色区域将分别被移动到图像的下侧和右侧,如图4.(d)。在移位的基础上再按照W-MSA切分窗口,就会得到和W-MSA不同的窗口切分方式,如图4.(d)中红色和蓝色分别是W-MSA和SW-MSA的切分窗口的结果。这一部分可以通过pytorch的roll函数实现,源码中是CyclicShift函数。

    class CyclicShift(nn.Module):
        def __init__(self, displacement):
            super().__init__()
            self.displacement = displacement
    
        def forward(self, x):
            return torch.roll(x, shifts=(self.displacement, self.displacement), dims=(1, 2))

    这种窗口切分方式引入了一个新的问题,即在移位图像的最后一行和最后一列各引入了一块移位过来的区域,如图4.(d)。只需要对比图4.(d)中的一个窗口中相同颜色的区域计算注意力,我们以图4.(d)左下角的区域(1)为例来说明SW-MSA是怎么实现这个功能的。

    区域(1)的计算如图5所示。首先一个 [公式] 大小的窗口通过线性预算得到 [公式] , [公式] , [公式] 三个权值,如我们介绍的,它的维度是 [公式] 。在这个49中,前28个是按照滑窗的方式遍历区域(1)中的前48个像素得到的,后21个则是遍历区域(1)的下半部分得到的,此时他们对应的位置关系依旧保持上黄下蓝的性质。

    接着便是计算 [公式] ,在图中相同颜色区域的相互计算后会依旧保持颜色,而黄色和蓝色区域计算后会变成绿色,而绿色的部分便是无意义的相似度。在论文中使用了upper_lower_mask将其掩码掉,upper_lower_mask是由 [公式] 和无穷大组成的二值矩阵,最后通过单位加之后得到最终的dots变量。

    下面两张图解释了右下角块1(7*7*32)和有边块的处理方式:

     

     

    最后我们介绍一下Swin Transformer的输出层,在stage-4完成计算后,特征的维度是 [公式] 。Swin Transformer先通过一个Global Average Pooling得到长度为768的特征向量,再通过一个LN和一个全连接得到最终的预测结果,如式(6)。

    [公式]

     
     
     
     
     
  • 相关阅读:
    总结7.13 tp5模板布局
    总结7.13 tp5图像处理
    Flask数据库
    java学习day72-JT项目10(Nginx服务器/tomcat部署/数据库高可用)
    java学习day71-Linux学习(基本指令)
    java学习day71-JT项目09(Linux/JDK/Mariadb/tomcat部署)
    java学习day70-JT项目08(图片回显/Nginx)
    java学习day69-JT项目07-(商品/详情一对一操作//文件上传)
    java学习day68-JT项目06(商品curd)
    java学习day67-JT项目05(商品分类树结构显示)
  • 原文地址:https://www.cnblogs.com/hellcat/p/15058984.html
Copyright © 2011-2022 走看看