zoukankan      html  css  js  c++  java
  • 莫烦课程Batch Normalization 批标准化

     for i in range(N_HIDDEN):               # build hidden layers and BN layers
                input_size = 1 if i == 0 else 10
                fc = nn.Linear(input_size, 10)
                setattr(self, 'fc%i' % i, fc)       # IMPORTANT set layer to the Module
                self._set_init(fc)                  # parameters initialization
                self.fcs.append(fc)
                if self.do_bn:
                    bn = nn.BatchNorm1d(10, momentum=0.5)
                    setattr(self, 'bn%i' % i, bn)   # IMPORTANT set layer to the Module
    self.bns.append(bn)
    

     上面的代码对每个隐层进行批标准化,setattr(self, 'fc%i' % i, fc)作用相当于self.fci=fc

    每次生成的结果append到bns的最后面,结果的size 10×10,取出这些数据是非常方便

    def forward(self, x):
            pre_activation = [x]
            if self.do_bn: x = self.bn_input(x)     # input batch normalization
            layer_input = [x]
            for i in range(N_HIDDEN):
                x = self.fcs[i](x)
                pre_activation.append(x)
                if self.do_bn: x = self.bns[i](x)   # batch normalization
                x = ACTIVATION(x)
                layer_input.append(x)
            out = self.predict(x)
    return out, layer_input, pre_activation
    

    全部的源代码

  • 相关阅读:
    第 12 章 Docker Swarm
    第 1 章 虚拟化
    第 0 章 写在最前面
    第 11 章 日志管理
    第 11 章 日志管理
    第 11 章 日志管理
    第 11 章 日志管理
    第 11 章 日志管理
    第 11 章 日志管理
    第 11 章 日志管理
  • 原文地址:https://www.cnblogs.com/lindaxin/p/8034069.html
Copyright © 2011-2022 走看看