  • ViT (Vision Transformer) | A First Look at the PyTorch Code


    • Originally published on the WeChat public account 「机器学习炼丹术」
    • Author: 炼丹兄
    • Contact: WeChat cyx645016617

    • The code comes from GitHub

    [Preface]: While reading the code you may not yet understand what each component of ViT means, but the goal of this article is to understand the implementation. Later, when you read the paper, you will have a clear picture in mind instead of being completely lost.

    The ViT class

    Initialization

    As in earlier walkthroughs, we start from the top-level model class and then work down through the smaller ones:

    import torch
    import torch.nn.functional as F
    from torch import nn
    from einops import rearrange, repeat

    # constant from the reference implementation (the assert below expects at least 16)
    MIN_NUM_PATCHES = 16

    class ViT(nn.Module):
        def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
            super().__init__()
            assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
            num_patches = (image_size // patch_size) ** 2
            patch_dim = channels * patch_size ** 2
            assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective (at least 16). Try decreasing your patch size'
            assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
    
            self.patch_size = patch_size
    
            self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
            self.patch_to_embedding = nn.Linear(patch_dim, dim)
            self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
            self.dropout = nn.Dropout(emb_dropout)
    
            self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
    
            self.pool = pool
            self.to_latent = nn.Identity()
    
            self.mlp_head = nn.Sequential(
                nn.LayerNorm(dim),
                nn.Linear(dim, num_classes)
            )
    

    In practice, the model is instantiated like this:

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model = ViT(
        dim=128,
        image_size=224,
        patch_size=32,
        num_classes=2,
        channels=3,
        depth=12,    # depth/heads/mlp_dim have no defaults; values match the comments below
        heads=8,
        mlp_dim=128,
    ).to(device)
    

    The input arguments:

    • image_size: the size of the input image;
    • patch_size: the size of each small patch the image is divided into;
    • num_classes: the total number of classes in this classification task;
    • channels: the number of channels of the input image;
    • dim: the embedding dimension of each token;
    • depth / heads / mlp_dim: the number of Transformer blocks, the number of attention heads, and the hidden dimension of the feed-forward layer; they have no defaults, so the call above passes the values used throughout this article. A quick smoke test follows this list.
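
    As a quick sanity check, here is a sketch of a forward pass on random input (it assumes the supporting classes defined later in this article are in scope):

    # hypothetical smoke test: a batch of 4 random RGB images
    img = torch.randn(4, 3, 224, 224).to(device)
    logits = model(img)
    print(logits.shape)  # expected: torch.Size([4, 2])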

    Components initialized in the ViT class:

    • num_patches: how many patches one image is divided into; the image is 224 and each patch is 32, so it is split into 7x7 = 49 patches;
    • patch_dim: 3x32x32, i.e. the number of elements in one patch;

    ...... Presenting it this way is rather tedious, since you have to keep flipping back and forth through the code, so instead I have written the explanations as comments:

    class ViT(nn.Module):
        def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
            # image_size=224, patch_size=32, num_classes=2, channels=3, dim=128
            super().__init__()
            assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
            # num_patches = (224 // 32) ** 2 = 7 * 7 = 49
            num_patches = (image_size // patch_size) ** 2
            # patch_dim = 3 * 32 * 32
            patch_dim = channels * patch_size ** 2
            assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective (at least 16). Try decreasing your patch size'
            assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
            # self.patch_size = 32
            self.patch_size = patch_size
            # self.pos_embedding is a learnable parameter of shape (1, 50, 128)
            self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
            # self.patch_to_embedding is a linear layer mapping 3*32*32 -> 128
            self.patch_to_embedding = nn.Linear(patch_dim, dim)
            # self.cls_token is a randomly initialized parameter of shape (1, 1, 128)
            self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
            self.dropout = nn.Dropout(emb_dropout)

            # Transformer is covered below
            self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

            self.pool = pool
            self.to_latent = nn.Identity()

            self.mlp_head = nn.Sequential(
                nn.LayerNorm(dim),
                nn.Linear(dim, num_classes)
            )
    

    forward

    Now let's look at ViT's forward pass:

        def forward(self, img, mask = None):
            # p = 32
            p = self.patch_size
            x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
            x = self.patch_to_embedding(x) # x.shape = [b, 49, 128]
            b, n, _ = x.shape # n = 49

            cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
            x = torch.cat((cls_tokens, x), dim=1) # x.shape = [b, 50, 128]
            x += self.pos_embedding[:, :(n + 1)] # x.shape = [b, 50, 128]
            x = self.dropout(x)

            x = self.transformer(x, mask) # x.shape = [b, 50, 128], mask = None

            x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

            x = self.to_latent(x)
            return self.mlp_head(x)
    
    • The code uses from einops import rearrange, repeat. einops is a tensor-manipulation library that supports PyTorch, TensorFlow, and other frameworks.
    • einops.rearrange reshapes the input img from [b, 3, 224, 224] to [b, 3, 7, 32, 7, 32], transposes it into [b, 7, 7, 32, 32, 3], and finally merges it into [b, 49, 32x32x3].
    • self.patch_to_embedding then produces an x of shape [b, 49, 128];
    • einops.repeat copies self.cls_token from [1, 1, 128] to [b, 1, 128].

    So now we know that the patch-to-embedding step is implemented with a single linear layer; the two einops calls are demonstrated in the short sketch below.
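
    A minimal standalone sketch of the two einops operations (dummy tensors, shapes only):

    import torch
    from einops import rearrange, repeat

    img = torch.randn(2, 3, 224, 224)  # dummy batch, b=2
    # split into 7x7 patches of 32x32x3 and flatten each patch
    patches = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=32, p2=32)
    print(patches.shape)  # torch.Size([2, 49, 3072]), 3072 = 32*32*3

    cls_token = torch.randn(1, 1, 128)
    cls_tokens = repeat(cls_token, '() n d -> b n d', b=2)
    print(cls_tokens.shape)  # torch.Size([2, 1, 128])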

    Transformer

    class Transformer(nn.Module):
        def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout):
            # dim=128,depth=12,heads=8,dim_head=64,mlp_dim=128
            super().__init__()
            self.layers = nn.ModuleList([])
            for _ in range(depth):
                self.layers.append(nn.ModuleList([
                    Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
                    Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
                ]))
        def forward(self, x, mask = None):
            for attn, ff in self.layers:
                x = attn(x, mask = mask)
                x = ff(x)
            return x
    
    • self.layers contains depth groups of Attention + FeedForward modules.
    • Keep in mind that the input x here has shape [b, 50, 128].

    Attention

    class Attention(nn.Module):
        def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
            super().__init__()
            inner_dim = dim_head * heads # 64 x 8
            self.heads = heads # 8
            self.scale = dim_head ** -0.5

            self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, dim),
                nn.Dropout(dropout)
            )

        def forward(self, x, mask = None):
            b, n, _, h = *x.shape, self.heads # n=50, h=8
            # self.to_qkv(x) has shape [b, 50, 64x8x3]; chunk splits it into 3 pieces,
            # i.e. qkv is a 3-tuple whose elements each have shape [b, 50, 64x8]
            qkv = self.to_qkv(x).chunk(3, dim = -1)
            # reshape each piece from [b, 50, 64x8] to [b, 8, 50, 64]
            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
            # this step is a bit subtle: q and k are both [b, 8, 50, 64];
            # think of 50 as the number of tokens and 64 as the feature dimension
            # dots.shape = [b, 8, 50, 50]
            dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
            # the mask branch can be ignored here (mask is None in our case)
            mask_value = -torch.finfo(dots.dtype).max

            if mask is not None:
                mask = F.pad(mask.flatten(1), (1, 0), value = True)
                assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
                mask = mask[:, None, :] * mask[:, :, None]
                dots.masked_fill_(~mask, mask_value)
                del mask
            # softmax over the last dimension of [b, 8, 50, 50]
            attn = dots.softmax(dim=-1)

            # attn holds the computed self-attention weights; multiplying with v gives out.shape = [b, 8, 50, 64]
            out = torch.einsum('bhij,bhjd->bhid', attn, v)
            # out.shape becomes [b, 50, 8x64]
            out = rearrange(out, 'b h n d -> b n (h d)')
            # out.shape becomes [b, 50, 128] again
            out = self.to_out(out)
            return out
    

    To sum up, this Attention is simply a self-attention module: the input is [b, 50, 128] and the output is also [b, 50, 128]. The implementation looks complicated because of torch.einsum, but overall it is exactly the same as the "non-local" module from a paper I covered previously. torch.einsum works on the same principle as torch.mm; it is used here only because torch.mm does not support matrix multiplication on higher-dimensional tensors.
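
    A small sketch (dummy tensors) showing that this einsum is just a batched matrix multiplication:

    import torch

    q = torch.randn(2, 8, 50, 64)
    k = torch.randn(2, 8, 50, 64)

    dots_einsum = torch.einsum('bhid,bhjd->bhij', q, k)
    dots_matmul = q @ k.transpose(-1, -2)  # the same computation as a batched matmul

    print(dots_einsum.shape)  # torch.Size([2, 8, 50, 50])
    print(torch.allclose(dots_einsum, dots_matmul, atol=1e-5))  # True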

    PreNorm

    class PreNorm(nn.Module):
        def __init__(self, dim, fn):
            # dim=128, fn = Attention or FeedForward
            super().__init__()
            self.norm = nn.LayerNorm(dim)
            self.fn = fn
        def forward(self, x, **kwargs):
            return self.fn(self.norm(x), **kwargs)
    

    First apply layer normalization (LayerNorm) to the input x (x.shape = [b, 50, 128]), then feed the result into the Attention module above for self-attention.
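
    A minimal sketch of what nn.LayerNorm(128) does here (dummy tensor): each token's 128-dim vector is normalized to zero mean and unit variance before the learnable scale and shift are applied:

    import torch
    from torch import nn

    x = torch.randn(2, 50, 128)
    y = nn.LayerNorm(128)(x)
    # every 128-dim vector is normalized independently over the last dimension
    print(y.mean(dim=-1).abs().max())  # close to 0
    print(y.std(dim=-1).mean())        # close to 1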

    Residual

    class Residual(nn.Module):
        def __init__(self, fn):
            super().__init__()
            self.fn = fn
        def forward(self, x, **kwargs):
            return self.fn(x, **kwargs) + x
    

    Just a residual module, nothing more.

    FeedForward

    class FeedForward(nn.Module):
        def __init__(self, dim, hidden_dim, dropout = 0.):
            # dim=128, hidden_dim=128
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(dim, hidden_dim),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_dim, dim),
                nn.Dropout(dropout)
            )
        def forward(self, x):
            return self.net(x)
    

    It is just two linear layers. The interesting part is the GELU() activation function, which can be called directly as torch.nn.GELU(); I will properly explain how GELU() works another time. A quick numeric sketch of its definition follows.
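
    For reference, GELU(x) = x * Φ(x), where Φ is the CDF of the standard normal distribution; a quick check (dummy values) against nn.GELU():

    import torch
    from torch import nn

    x = torch.linspace(-3, 3, 7)
    # GELU(x) = x * Phi(x), with Phi computed via the error function
    manual = x * 0.5 * (1 + torch.erf(x / 2 ** 0.5))
    print(torch.allclose(nn.GELU()(x), manual, atol=1e-6))  # True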

    Transformer summary

    Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
    Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
    
    • The first line: apply layer normalization to the input, pass it through Attention, then add the attention output back to the pre-normalization input as a residual connection;
    • The second line: x -> LayerNorm -> FeedForward linear layers -> y, then add y to the input x as a residual connection. One block is sketched in code form below.
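
    In other words, each Transformer block computes the following (a sketch; attention, feed_forward, norm1 and norm2 are illustrative names for the wrapped modules):

    x = x + attention(norm1(x))     # Residual(PreNorm(dim, Attention(...)))
    x = x + feed_forward(norm2(x))  # Residual(PreNorm(dim, FeedForward(...)))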

    ViT summary

    Let's review the whole pipeline:

    • A 224x224 image is split into 49 patches of size 32x32;
    • All these patches are embedded, giving 49 vectors of dimension 128;
    • A cls_token is concatenated in front, giving 50 vectors of dimension 128;
    • The pos_embedding is added; still 50 vectors of dimension 128;
    • These vectors are fed into the Transformer for self-attention feature extraction;
    • The output is 50 vectors of dimension 128; the cls token's vector is taken (by default pool='cls'; with pool='mean' the 50 vectors would be averaged instead), leaving a single 128-dim vector;
    • A final linear layer maps the 128 dimensions to 2, completing the Transformer model for this binary classification task. The full shape trace is summarized below.
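
    Putting it all together as a shape trace (hypothetical batch size b=4):

    # img:                 [4, 3, 224, 224]
    # rearrange:           [4, 49, 3072]   (49 patches, 32*32*3 = 3072)
    # patch_to_embedding:  [4, 49, 128]
    # cat cls_token:       [4, 50, 128]
    # add pos_embedding:   [4, 50, 128]
    # transformer:         [4, 50, 128]
    # pool ('cls'):        [4, 128]
    # mlp_head:            [4, 2]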

    Question: I don't know NLP in depth, so can anyone answer this: what exactly are cls_tokens and pos_embedding for?
