  • Reproducing a Self-Supervised Image Paper | BYOL (PyTorch) | 2020

    Continuing from the previous post, which covered the paper and architecture of the Bootstrap Your Own Latent (BYOL) self-supervised model:
    https://juejin.cn/post/6922347006144970760

    Now let's look at how to implement this architecture in PyTorch, deepening our understanding of the paper along the way.
    github: https://github.com/lucidrains/byol-pytorch

    [Preface]: I haven't actually run this code myself; after all, I'm just a poor soul without a GPU.

    Main model code

    # imports for this excerpt; helpers such as RandomApply, default, EMA, NetWrapper,
    # MLP, singleton, get_module_device and set_requires_grad are defined elsewhere
    # in the repo's byol_pytorch.py
    import copy
    import torch
    import torch.nn.functional as F
    from torch import nn
    from torchvision import transforms as T

    class BYOL(nn.Module):
        def __init__(
            self,
            net,
            image_size,
            hidden_layer = -2,
            projection_size = 256,
            projection_hidden_size = 4096,
            augment_fn = None,
            augment_fn2 = None,
            moving_average_decay = 0.99,
            use_momentum = True
        ):
            super().__init__()
            self.net = net
    
            # default SimCLR augmentation
    
            DEFAULT_AUG = torch.nn.Sequential(
                RandomApply(
                    T.ColorJitter(0.8, 0.8, 0.8, 0.2),
                    p = 0.3
                ),
                T.RandomGrayscale(p=0.2),
                T.RandomHorizontalFlip(),
                RandomApply(
                    T.GaussianBlur((3, 3), (1.0, 2.0)),
                    p = 0.2
                ),
                T.RandomResizedCrop((image_size, image_size)),
                T.Normalize(
                    mean=torch.tensor([0.485, 0.456, 0.406]),
                    std=torch.tensor([0.229, 0.224, 0.225])),
            )
    
            self.augment1 = default(augment_fn, DEFAULT_AUG)
            self.augment2 = default(augment_fn2, self.augment1)
    
            self.online_encoder = NetWrapper(net, projection_size, projection_hidden_size, layer=hidden_layer)
    
            self.use_momentum = use_momentum
            self.target_encoder = None
            self.target_ema_updater = EMA(moving_average_decay)
    
            self.online_predictor = MLP(projection_size, projection_size, projection_hidden_size)
    
            # get device of network and make wrapper same device
            device = get_module_device(net)
            self.to(device)
    
            # send a mock image tensor to instantiate singleton parameters
            self.forward(torch.randn(2, 3, image_size, image_size, device=device))
    
        @singleton('target_encoder')
        def _get_target_encoder(self):
            target_encoder = copy.deepcopy(self.online_encoder)
            set_requires_grad(target_encoder, False)
            return target_encoder
    
        def reset_moving_average(self):
            del self.target_encoder
            self.target_encoder = None
    
        def update_moving_average(self):
            assert self.use_momentum, 'you do not need to update the moving average, since you have turned off momentum for the target encoder'
            assert self.target_encoder is not None, 'target encoder has not been created yet'
            update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder)
    
        def forward(self, x, return_embedding = False):
            if return_embedding:
                return self.online_encoder(x)
    
            image_one, image_two = self.augment1(x), self.augment2(x)
    
            online_proj_one, _ = self.online_encoder(image_one)
            online_proj_two, _ = self.online_encoder(image_two)
    
            online_pred_one = self.online_predictor(online_proj_one)
            online_pred_two = self.online_predictor(online_proj_two)
    
            with torch.no_grad():
                target_encoder = self._get_target_encoder() if self.use_momentum else self.online_encoder
                target_proj_one, _ = target_encoder(image_one)
                target_proj_two, _ = target_encoder(image_two)
                target_proj_one.detach_()
                target_proj_two.detach_()
    
            loss_one = loss_fn(online_pred_one, target_proj_two.detach())
            loss_two = loss_fn(online_pred_two, target_proj_one.detach())
    
            loss = loss_one + loss_two
            return loss.mean()
    
    • Start with forward(): you feed an image batch into the model, and the return value is the loss computed on that batch.
    • For inference, set return_embedding=True; the return value is then the online encoder's output (in this excerpt, a (projection, representation) tuple), and the predictor is not involved. Note that the "encoder" in the code is really the paper's encoder + projector (the NetWrapper below).
    • The input is processed by self.augment1 and self.augment2 into two different images, which we called views in the previous post;
    • Both views go through the online encoder. You might object: shouldn't one image go through the online network and the other through the target network? Quite right, but BYOL uses a symmetrized loss, so each view has to pass through both the online network and the target network.
    • Everything computed by the target network runs under torch.no_grad(), because the target network is updated from the online network's parameters rather than by gradient descent.
    • If self.use_momentum=False, the paper's momentum update is skipped and the online encoder itself is used as the target network. One thing worth flagging about this 600+ star repo: at first glance _get_target_encoder() seems to just copy the online network into the target network even when self.use_momentum=True, but the @singleton decorator caches that deep copy so it happens only once; afterwards update_moving_average() performs the EMA update, which the training loop has to call explicitly (see the sketch after this list).
    • Finally, the loss is computed with loss_fn, and the model returns loss.mean().
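
    The EMA class and the update_moving_average() helper used above aren't shown in this excerpt. Based on the repo, they look roughly like this (a sketch, not a verbatim copy):

    class EMA():
        """Exponential moving average: new_target = old * beta + (1 - beta) * online."""
        def __init__(self, beta):
            self.beta = beta

        def update_average(self, old, new):
            if old is None:
                return new
            return old * self.beta + (1 - self.beta) * new

    def update_moving_average(ema_updater, ma_model, current_model):
        # move each target (ma) parameter a small step toward the online (current) one
        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
            old_weight, up_weight = ma_params.data, current_params.data
            ma_params.data = ema_updater.update_average(old_weight, up_weight)

    With moving_average_decay = 0.99, the target encoder therefore drifts slowly toward the online encoder rather than being overwritten by it.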

    So far, then, the BYOL structure turns out to be quite simple. Four points still need clarification:

    • How is the online_encoder defined?
    • How is the predictor defined?
    • How are the image augmentations defined?
    • How is the loss function loss_fn defined?

    augment

    This fragment appeared in the code above:

    # default SimCLR augmentation
    
            DEFAULT_AUG = torch.nn.Sequential(
                RandomApply(
                    T.ColorJitter(0.8, 0.8, 0.8, 0.2),
                    p = 0.3
                ),
                T.RandomGrayscale(p=0.2),
                T.RandomHorizontalFlip(),
                RandomApply(
                    T.GaussianBlur((3, 3), (1.0, 2.0)),
                    p = 0.2
                ),
                T.RandomResizedCrop((image_size, image_size)),
                T.Normalize(
                    mean=torch.tensor([0.485, 0.456, 0.406]),
                    std=torch.tensor([0.229, 0.224, 0.225])),
            )
    
            self.augment1 = default(augment_fn, DEFAULT_AUG)
            self.augment2 = default(augment_fn2, self.augment1)
    

    From this we can see:

    • This is the image-augmentation pipeline; augment1 and augment2 can be customized, and by default both fall back to the DEFAULT_AUG above;
    • T here comes from: from torchvision import transforms as T

    The least familiar of these is probably torchvision.transforms.ColorJitter().

    According to the official API docs, this transform randomly changes an image's brightness, contrast, saturation, and hue.
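
    One helper the excerpt doesn't define is RandomApply, which applies a transform with some probability. In the repo it is a small nn.Module wrapper, roughly:

    import random
    import torch.nn as nn

    class RandomApply(nn.Module):
        """Apply fn to the input with probability p; otherwise pass it through unchanged."""
        def __init__(self, fn, p):
            super().__init__()
            self.fn = fn
            self.p = p

        def forward(self, x):
            if random.random() > self.p:
                return x
            return self.fn(x)

    Note that the whole DEFAULT_AUG pipeline is wrapped in torch.nn.Sequential and built from tensor-compatible transforms, so it operates on batched image tensors rather than PIL images.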

    encoder+projector

    class NetWrapper(nn.Module):
        def __init__(self, net, projection_size, projection_hidden_size, layer = -2):
            super().__init__()
            self.net = net
            self.layer = layer
    
            self.projector = None
            self.projection_size = projection_size
            self.projection_hidden_size = projection_hidden_size
    
            self.hidden = None
            self.hook_registered = False
    
        def _find_layer(self):
            # locate the hidden layer either by name (str) or by child index (int)
            if type(self.layer) == str:
                modules = dict([*self.net.named_modules()])
                return modules.get(self.layer, None)
            elif type(self.layer) == int:
                children = [*self.net.children()]
                return children[self.layer]
            return None
    
        def _hook(self, _, __, output):
            # forward hook: capture the hooked layer's output, flattened to (batch, -1);
            # flatten is a repo helper equivalent to t.reshape(t.shape[0], -1)
            self.hidden = flatten(output)
    
        def _register_hook(self):
            layer = self._find_layer()
            assert layer is not None, f'hidden layer ({self.layer}) not found'
            layer.register_forward_hook(self._hook)
            self.hook_registered = True
    
        @singleton('projector')
        def _get_projector(self, hidden):
            # build the projector lazily, once the representation's dimension is known
            _, dim = hidden.shape
            projector = MLP(dim, self.projection_size, self.projection_hidden_size)
            return projector.to(hidden)
    
        def get_representation(self, x):
            if self.layer == -1:
                return self.net(x)
    
            if not self.hook_registered:
                self._register_hook()
    
            _ = self.net(x)
            hidden = self.hidden
            self.hidden = None
            assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
            return hidden
    
        def forward(self, x, return_embedding = False):
            representation = self.get_representation(x)
    
            if return_embedding:
                return representation
    
            projector = self._get_projector(representation)
            projection = projector(representation)
            return projection, representation
    

    This is the encoder + projector: NetWrapper wraps the backbone network (the encoder) and attaches the projector.
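
    The @singleton decorator used by _get_target_encoder and _get_projector isn't shown either. In the repo it's a small caching decorator, roughly:

    from functools import wraps

    def singleton(cache_key):
        # cache the wrapped method's result on self under cache_key, building it only once
        def inner_fn(fn):
            @wraps(fn)
            def wrapper(self, *args, **kwargs):
                instance = getattr(self, cache_key)
                if instance is not None:
                    return instance
                instance = fn(self, *args, **kwargs)
                setattr(self, cache_key, instance)
                return instance
            return wrapper
        return inner_fn

    This is why the projector can be created lazily (its input dimension is only known after the first forward pass) and why the target encoder is deep-copied exactly once.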

    encoder

    The backbone is passed in as an argument when NetWrapper is constructed. Looking at the training script, the model is:

    from torchvision import models, transforms
    resnet = models.resnet50(pretrained=True)
    

    So the encoder is a resnet50, as in the paper. The full resnet50 outputs a (batch_size, 1000) tensor of class logits, but note that hidden_layer=-2 makes NetWrapper hook the second-to-last child (the global average pooling layer), so the representation actually fed to the projector is the flattened pooling output of shape (batch_size, 2048).
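
    A quick sanity check of those shapes (a hypothetical snippet; random weights, since only shapes matter here):

    import torch
    from torchvision import models

    resnet = models.resnet50(pretrained=False)

    children = list(resnet.children())
    print(type(children[-2]).__name__)           # AdaptiveAvgPool2d -- the layer hooked at index -2

    trunk = torch.nn.Sequential(*children[:-1])  # everything up to and including avgpool
    feats = trunk(torch.randn(2, 3, 256, 256))
    print(feats.shape)                           # torch.Size([2, 2048, 1, 1])
    print(feats.reshape(feats.shape[0], -1).shape)  # torch.Size([2, 2048]) after flatten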

    projector

    The projector is built from this MLP class:

    class MLP(nn.Module):
        def __init__(self, dim, projection_size, hidden_size = 4096):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(dim, hidden_size),
                nn.BatchNorm1d(hidden_size),
                nn.ReLU(inplace=True),
                nn.Linear(hidden_size, projection_size)
            )
    
        def forward(self, x):
            return self.net(x)
    

    It's a Linear + BatchNorm + ReLU stack followed by a final Linear, much as described in the paper, with no BN + ReLU after the last fully connected layer. Passing a representation through this MLP yields a tensor of shape (batch_size, projection_size), as the check below shows.
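
    For concreteness, here is a quick shape check using the resnet50 representation size from above (hypothetical values):

    import torch

    mlp = MLP(dim=2048, projection_size=256, hidden_size=4096)
    z = mlp(torch.randn(8, 2048))   # a batch of 8 representations
    print(z.shape)                  # torch.Size([8, 256])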

    predictor

    self.online_predictor = MLP(projection_size, projection_size, projection_hidden_size)
    

    This predictor is exactly the same kind of MLP as the projector; note that its input and output feature dimensions are both projection_size.

    Since I haven't read the self-supervised literature systematically (BYOL is the first paper I've looked at), I can't fully explain why the predictor exists. Empirically, it breaks the symmetry between the online and target branches: if the two branches were structurally identical, the networks could collapse to producing identical outputs for every input, trivially driving the loss to zero.

    loss_fn

    def loss_fn(x, y):
        x = F.normalize(x, dim=-1, p=2)
        y = F.normalize(y, dim=-1, p=2)
        return 2 - 2 * (x * y).sum(dim=-1)
    

    This part matches the paper exactly.
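
    Since both inputs are L2-normalized, 2 - 2 * (x * y).sum(dim=-1) is exactly the squared Euclidean distance between the normalized vectors, which is how the paper writes the loss. A quick numerical check:

    import torch
    import torch.nn.functional as F

    x, y = torch.randn(4, 256), torch.randn(4, 256)

    byol_loss = 2 - 2 * (F.normalize(x, dim=-1) * F.normalize(y, dim=-1)).sum(dim=-1)
    mse = (F.normalize(x, dim=-1) - F.normalize(y, dim=-1)).pow(2).sum(dim=-1)
    print(torch.allclose(byol_loss, mse, atol=1e-6))  # True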

    To sum up, this BYOL framework is a simple yet interesting self-supervised architecture.
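
    Finally, for completeness, here is how everything is wired together for training, in the style of the repo's README (treat it as a sketch; the parameter values are the README's defaults):

    import torch
    from torchvision import models
    from byol_pytorch import BYOL   # pip install byol-pytorch

    resnet = models.resnet50(pretrained=True)
    learner = BYOL(resnet, image_size=256, hidden_layer='avgpool')
    opt = torch.optim.Adam(learner.parameters(), lr=3e-4)

    def sample_unlabelled_images():
        # stand-in for a real unlabelled dataloader
        return torch.randn(20, 3, 256, 256)

    for _ in range(100):
        images = sample_unlabelled_images()
        loss = learner(images)
        opt.zero_grad()
        loss.backward()
        opt.step()
        learner.update_moving_average()  # the EMA step must be called explicitly

    # afterwards, use the online encoder (+ projector) for downstream tasks
    projection, embedding = learner(torch.randn(2, 3, 256, 256), return_embedding=True)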

    One must not be arrogant.
  • Original post: https://www.cnblogs.com/PythonLearner/p/14349979.html