  • BN

    There are a few things about BN that deserve attention:

    1. The pros and cons of the train/test inconsistency.
    2. Pitfalls at inference time: the moving average.
    3. Pitfalls during training: batch size and batch distribution.
    4. Pitfalls when fine-tuning: parameterization, data distribution, etc.
    5. Pitfalls in implementation: a multi-purpose BN implementation.
    6. Improvements such as GN, precise-BN, and so on.

    BN behaves differently at training time and at test time.

    During training, BN normalizes each batch with that batch's own mean and variance, and uses an EMA to update its running estimates of those statistics. At test time, BN does not compute batch statistics or update anything; it normalizes with the running (EMA) statistics accumulated during training.
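    A minimal PyTorch sketch of this difference (the layer size and input here are made up for illustration):

    import torch

    bn = torch.nn.BatchNorm1d(3)
    x = torch.randn(8, 3) * 5 + 2      # a batch whose mean/std differ from (0, 1)

    bn.train()
    y_train = bn(x)                    # normalized with this batch's own mean/var;
                                       # running_mean/running_var updated via EMA

    bn.eval()
    y_eval = bn(x)                     # normalized with the stored running stats;
                                       # nothing is recomputed or updated

    print(bn.running_mean)             # has moved toward the batch mean (~2)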

    1. When λ is too small, the EMA is dominated by the most recent batches and is not a reasonable approximation of the population statistics.
    2. When λ is too large, the EMA needs many iterations to converge (see the numeric sketch after this list).
    3. When the model, or the data, is unstable, the EMA can also run into problems.
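    A tiny numeric sketch of both failure modes (the stream of per-batch means and the λ values are invented for illustration; the update follows μ ← λ·μ + (1−λ)·μ_batch):

    import torch

    torch.manual_seed(0)
    # A made-up stream of 1000 per-batch means whose true mean is 5.0.
    batch_means = torch.randn(1000) + 5.0

    def ema(lam):
        est = 0.0
        for m in batch_means:
            est = lam * est + (1 - lam) * m.item()
        return est

    print(ema(0.9))     # near 5.0 but noisy: dominated by the last ~10 batches
    print(ema(0.9999))  # ~0.5 after 1000 iterations: far from 5.0, converges very slowly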

    Using Precise-BatchNorm

    Keep using the EMA machinery, but with a fairly large λ, and with the model weights frozen: just run forward passes for many iterations, so the statistics are estimated over a large, fixed sample.

    I haven't really read the paper Rethinking "Batch" in BatchNorm, but I did read the precise BN code:

    In case some of the functions used in it are unfamiliar, here are a couple of notes first.

    itertools.islice() takes a slice of an iterator, and it consumes the underlying iterator as it advances.
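    For example:

    import itertools

    it = iter(range(10))
    print(list(itertools.islice(it, 3)))  # [0, 1, 2]
    print(next(it))                       # 3 -- the sliced elements were consumed from `it`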

    running_mean[i] += (bn.running_mean - running_mean[i]) / (ind + 1)
    running_var[i] += (bn.running_var - running_var[i]) / (ind + 1)

    This is straightforward: the incremental update is equivalent to summing first and then averaging, i.e. after k iterations running_mean[i] is the plain mean of the k per-batch values.
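    A quick check of that equivalence on a toy stream of values:

    import torch

    xs = torch.randn(10)        # pretend these are 10 per-batch means
    running = torch.zeros(())
    for ind, x in enumerate(xs):
        running += (x - running) / (ind + 1)

    assert torch.allclose(running, xs.mean())  # incremental update == plain average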

    #!/usr/bin/env python3
    # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
    
    import itertools
    
    import torch
    
    BN_MODULE_TYPES = (
        torch.nn.BatchNorm1d,
        torch.nn.BatchNorm2d,
        torch.nn.BatchNorm3d,
        torch.nn.SyncBatchNorm,
    )
    
    
    @torch.no_grad()
    def update_bn_stats(model, data_loader, num_iters: int = 200):
        """
        Recompute and update the batch norm stats to make them more precise. During
        training both BN stats and the weight are changing after every iteration, so
        the running average can not precisely reflect the actual stats of the
        current model.
        In this function, the BN stats are recomputed with fixed weights, to make
        the running average more precise. Specifically, it computes the true average
        of per-batch mean/variance instead of the running average.
    
        Args:
            model (nn.Module): the model whose bn stats will be recomputed.
    
                Note that:
    
                1. This function will not alter the training mode of the given model.
               Users are responsible for setting the layers that need
                   precise-BN to training mode, prior to calling this function.
    
                2. Be careful if your models contain other stateful layers in
                   addition to BN, i.e. layers whose state can change in forward
                   iterations.  This function will alter their state. If you wish
                   them unchanged, you need to either pass in a submodule without
                   those layers, or backup the states.
            data_loader (iterator): an iterator. Produce data as inputs to the model.
            num_iters (int): number of iterations to compute the stats.
        """
        bn_layers = get_bn_modules(model)
    
        if len(bn_layers) == 0:
            return
    
        # In order to make the running stats only reflect the current batch, the
        # momentum is disabled.
        # bn.running_mean = (1 - momentum) * bn.running_mean + momentum * batch_mean
        # Setting the momentum to 1.0 to compute the stats without momentum.
        momentum_actual = [bn.momentum for bn in bn_layers]
        for bn in bn_layers:
            bn.momentum = 1.0
    
        # Note that running_var actually means "running average of variance"
        running_mean = [torch.zeros_like(bn.running_mean) for bn in bn_layers]
        running_var = [torch.zeros_like(bn.running_var) for bn in bn_layers]
    
        for ind, inputs in enumerate(itertools.islice(data_loader, num_iters)):
            model(inputs)
    
            for i, bn in enumerate(bn_layers):
                # Accumulates the bn stats.
                running_mean[i] += (bn.running_mean - running_mean[i]) / (ind + 1)
                running_var[i] += (bn.running_var - running_var[i]) / (ind + 1)
                # We compute the "average of variance" across iterations.
        assert ind == num_iters - 1, (
            "update_bn_stats is meant to run for {} iterations, "
            "but the dataloader stops at {} iterations.".format(num_iters, ind)
        )
    
        for i, bn in enumerate(bn_layers):
            # Sets the precise bn stats.
            bn.running_mean = running_mean[i]
            bn.running_var = running_var[i]
            bn.momentum = momentum_actual[i]
    
    
    def get_bn_modules(model):
        """
        Find all BatchNorm (BN) modules that are in training mode. See
        cvpack2.modeling.nn_utils.precise_bn.BN_MODULE_TYPES for a list of all modules that are
        included in this search.
    
        Args:
            model (nn.Module): a model possibly containing BN modules.
    
        Returns:
            list[nn.Module]: all BN modules in the model.
        """
        # Finds all the bn layers.
        bn_layers = [
            m
            for m in model.modules()
            if m.training and isinstance(m, BN_MODULE_TYPES)
        ]
        return bn_layers
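
    A hedged usage sketch of the helper above (the model and the toy data loader are stand-ins; any iterator of model inputs works):

    model = torch.nn.Sequential(
        torch.nn.Conv2d(3, 8, 3, padding=1),
        torch.nn.BatchNorm2d(8),
        torch.nn.ReLU(),
    )

    # Any iterator that yields model inputs serves as the "data loader" here.
    data_loader = (torch.randn(4, 3, 32, 32) for _ in range(200))

    model.train()  # BN layers must be in training mode for get_bn_modules to pick them up
    update_bn_stats(model, data_loader, num_iters=200)
    model.eval()   # inference now uses the recomputed, more precise stats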
    
    
  • Original post: https://www.cnblogs.com/JohnRan/p/15098398.html