zoukankan      html  css  js  c++  java
  • 动手学pytorch-Batch Norm

    批量归一化

    1.基本概念

    2.代码实现

    1.基本概念

    对输入的标准化(浅层模型)
    处理后的任意一个特征在数据集中所有样本上的均值为0、标准差为1。
    标准化处理输入数据使各个特征的分布相近
    批量归一化(深度模型)
    利用小批量上的均值和标准差,不断调整神经网络中间输出,从而使整个神经网络在各层的中间输出的数值更稳定。

    1.1对全连接层做批量归一化

    位置:全连接层中的仿射变换和激活函数之间。
    全连接:

    [oldsymbol{x} = oldsymbol{Woldsymbol{u} + oldsymbol{b}} \ output =phi(oldsymbol{x}) ]

    批量归一化:

    [output=phi( ext{BN}(oldsymbol{x})) ]

    [oldsymbol{y}^{(i)} = ext{BN}(oldsymbol{x}^{(i)}) ]

    [oldsymbol{mu}_mathcal{B} leftarrow frac{1}{m}sum_{i = 1}^{m} oldsymbol{x}^{(i)}, ]

    [oldsymbol{sigma}_mathcal{B}^2 leftarrow frac{1}{m} sum_{i=1}^{m}(oldsymbol{x}^{(i)} - oldsymbol{mu}_mathcal{B})^2, ]

    [hat{oldsymbol{x}}^{(i)} leftarrow frac{oldsymbol{x}^{(i)} - oldsymbol{mu}_mathcal{B}}{sqrt{oldsymbol{sigma}_mathcal{B}^2 + epsilon}}, ]

    这⾥ϵ > 0是个很小的常数,保证分母大于0

    [{oldsymbol{y}}^{(i)} leftarrow oldsymbol{gamma} odot hat{oldsymbol{x}}^{(i)} + oldsymbol{eta}. ]

    引入可学习参数:拉伸参数γ和偏移参数β。若(oldsymbol{gamma} = sqrt{oldsymbol{sigma}_mathcal{B}^2 + epsilon})(oldsymbol{eta} = oldsymbol{mu}_mathcal{B}),批量归一化无效。

    1.2对卷积层做批量归⼀化

    位置:卷积计算之后、应⽤激活函数之前。
    如果卷积计算输出多个通道,我们需要对这些通道的输出分别做批量归一化,且每个通道都拥有独立的拉伸和偏移参数。
    计算:对单通道,batchsize=m,卷积计算输出=pxq
    对该通道中m×p×q个元素同时做批量归一化,使用相同的均值和方差。

    1.3预测时的批量归⼀化

    训练:以batch为单位,对每个batch计算均值和方差。
    预测:用移动平均估算整个训练数据集的样本均值和方差。

    2.代码实现

    class BatchNorm(nn.Module):
        def __init__(self, *, num_features, num_dims):
            super(BatchNorm, self).__init__()
            super(BatchNorm, self).__init__()
            if num_dims == 2:
                shape = (1, num_features) #全连接层输出神经元
            else:
                shape = (1, num_features, 1, 1)  #通道数
            # 参与求梯度和迭代的拉伸和偏移参数,分别初始化成0和1
            self.gamma = nn.Parameter(torch.ones(shape))
            self.beta = nn.Parameter(torch.zeros(shape))
            # 不参与求梯度和迭代的变量,全在内存上初始化成0
            self.moving_mean = torch.zeros(shape)
            self.moving_var = torch.zeros(shape)
            self.momentum = 0.9
        
        def forward(self, X):
            if self.moving_mean.device != X.device:
                self.moving_mean = self.moving_mean.to(X.device)
                self.moving_var = self.moving_var.to(X.device)
            # 保存更新过的moving_mean和moving_var, Module实例的traning属性默认为true, 调用.eval()后设成false
            Y, self.moving_mean, self.moving_var = self._batch_norm(self.training, 
                X, self.gamma, self.beta, self.moving_mean,
                self.moving_var, eps=1e-5, momentum=self.momentum)
            return Y
    
        def _batch_norm(self, is_training, X, gamma, beta, moving_mean, moving_var, eps, momentum):
            if not is_training:
                # 如果是在预测模式下,直接使用传入的移动平均所得的均值和方差
                X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
            else:
                assert len(X.shape) in (2, 4)
                if len(X.shape) == 2:
                    # 使用全连接层的情况,计算特征维上的均值和方差
                    mean = X.mean(dim=0)
                    var = ((X - mean) ** 2).mean(dim=0)
                else:
                    # 使用二维卷积层的情况,计算通道维上(axis=1)的均值和方差。这里我们需要保持
                    # X的形状以便后面可以做广播运算
                    mean = X.mean(dim=0, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
                    var = ((X - mean) ** 2).mean(dim=0, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
                # 训练模式下用当前的均值和方差做标准化
                X_hat = (X - mean) / torch.sqrt(var + eps)
                # 更新移动平均的均值和方差
                moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
                moving_var = momentum * moving_var + (1.0 - momentum) * var
            Y = gamma * X_hat + beta  # 拉伸和偏移
            return Y, moving_mean, moving_var
    
    

    带batch norm 的LeNet

    class BLeNet(nn.Module):
        def __init__(self, *, channels, fig_size, num_class):
            super(BLeNet, self).__init__()
            self.conv = nn.Sequential(
                nn.Conv2d(channels, 6, 5, padding=2),
                BatchNorm(num_features=6, num_dims = 4),
                nn.Sigmoid(),
                nn.AvgPool2d(2, 2),
                nn.Conv2d(6, 16, 5),
                BatchNorm(num_features=16, num_dims = 4),
                nn.Sigmoid(),
                nn.AvgPool2d(2, 2),
            )
            ##经过卷积和池化层后的图像大小
            fig_size = (fig_size - 5 + 1 + 4 ) // 1
            fig_size = (fig_size - 2 + 2) // 2
            fig_size = (fig_size - 5 + 1) // 1
            fig_size = (fig_size - 2 + 2) // 2
            self.fc = nn.Sequential(
                nn.Flatten(),
                nn.Linear(16 * fig_size * fig_size, 120),
                BatchNorm(num_features=120, num_dims = 2),
                nn.Sigmoid(),
                nn.Linear(120, 84),
                BatchNorm(num_features=84, num_dims = 2),
                nn.Sigmoid(),
                nn.Linear(84, num_class),
            )
        def forward(self, X):
            conv_features = self.conv(X)
            output = self.fc(conv_features)
            return output
    
    
  • 相关阅读:
    yb课堂之自定义异常和配置 《五》
    文件包含总结--2018自我整理
    文件上传总结--2018自我整理
    i春秋 “百度杯”CTF比赛 十月场 web题 Backdoor
    bugku web题INSERT INTO注入
    SCTF2018-Event easiest web
    初识thinkphp(5)
    “百度杯”CTF比赛 九月场 YeserCMS
    初识thinkphp(4)
    0MQ是会阻塞的,不要字面上看到队列就等同非阻塞。
  • 原文地址:https://www.cnblogs.com/54hys/p/12335020.html
Copyright © 2011-2022 走看看