zoukankan      html  css  js  c++  java
  • 批量归一化

    批量归一化实际上也是一个层,可以直接调用torch的api。

    image-20211118184726584

    我们在训练顶部的层,实际上干的就是在拟合底部层。所以顶部的层会收敛的很快,但是底部的层数据变化后,顶部又要重新进行训练。

    一个简单的实现就是将分布的均值和方差进行固定。这个实现也很简单。

    \[\mu_{B}=\frac{1}{|B|} \sum_{i \in B} x_{i} \text { and } \sigma_{B}^{2}=\frac{1}{|B|} \sum_{i \in B}\left(x_{i}-\mu_{B}\right)^{2}+\epsilon \]

    \[x_{i+1}=\gamma \frac{x_{i}-\mu_{B}}{\sigma_{B}}+\beta \]

    这里的\(\mu\) 均值和 \(\sigma\) 方差是通过batch进行计算得到的,然后在通过归一化操作得到新的输出。其中\(\gamma\)\(\beta\) 是 可训练参数。

    批量归一化层作用于激活函数之前。

    这里批量归一化在卷积里面是作用在通道维的。因为卷积的shape是 (batchsize, channel, high, width) 这里的每一个channel实际上就是一个特征维度。

    如果批量归一化作用在全连接层,那么它是作用在特征维,比如shape (batchsize, features_num) 那么求均值后就是(, features_num)。批量归一化是对批量进行归一化,而不是一个样本里的特征做归一化。反正很难说清楚。

    批量归一化在做什么?

    它可能是通过在小批量里加入噪音来控制模型复杂度,所以没有必要和丢弃法混合使用。批量归一化层只会加速收敛,不会改变模型精度。

    torch代码

    import torch
    import torch.nn as nn
    con = nn.Conv2d(2, 3, 1)
    nor = nn.LazyBatchNorm2d()
    tmp = torch.ones((1, 2, 28, 28))
    s = nor(con(tmp))
    s = s.permute(1,0,2,3).reshape(3,-1)
    

    其实torch里有batchnorm2d,但我搞不懂为什么要输入num_input,这个参数不是隐含在input size里吗?就是size(1),所以这里我用了LazyBatchNorm2d()

    image-20211118192753233

    eps,就是计算batch的\(\sigma\) 方差所用到的 \(\epsilon\) ,默认取0就行了。这里的moment是用于计算running_mean 和running_var用到的。

    \[\hat{x}_{\text {new }}=(1-\text { momentum }) \times \hat{x}+\text { momentum } \times x_{t} \]

    这里的running_mean 和 running_var 就是用该batch得到mean和var进行更新。然后在eval时候,我们就不用batch 的mean和var进行归一化,我们会使用训练得到的running_mean 和 running_var 进行更新。

    如果tack_running_stats 设置为false,那么我们在训练过程就不会去计算running_mean 和running_var。在eval的时候也是用eval的batch mean和var进行归一化。

    affine默认设置为true,表示\(\gamma\)\(\beta\) 是 可训练参数,如果设置为false就不训练。

    这里有个问题:Why does nn.Conv2d require in_channels?

    可以很好的解决我的疑惑,对于卷积核需要初始化,直接在init时候就定义好卷积核大小或许是个好的行为。

    比如在这里:

    import torch
    from torch.functional import norm
    import torch.nn as nn
    con = nn.Conv2d(2, 3, 1)
    nor = nn.LazyBatchNorm2d()
    tmp = torch.ones((1, 2, 28, 28))
    s = nor(con(tmp))
    t = torch.ones((1, 4, 28, 28))
    s = s.permute(1,0,2,3).reshape(3,-1)
    nor(t)
    

    由于使用了LazyNorm先对(1,3,28,28)shape进行归一化,于是在对(1,4,28,28)归一化的时候就会报错。

  • 相关阅读:
    Introduction to XQuery in SQL Server 2005
    [译]Cassandra 架构简述
    冬日绘版实录
    网页实现串口TCP数据通讯的两种方案
    (转)感知哈希算法
    CoreParking
    单线程扫描电脑所有文件与并行计算扫描电脑所有文件所用时间?
    强名称程序集
    一些题(六)
    一些题(五)
  • 原文地址:https://www.cnblogs.com/kalicener/p/15574007.html
Copyright © 2011-2022 走看看