zoukankan      html  css  js  c++  java
  • 『MXNet』第四弹_Gluon自定义层

    https://www.cnblogs.com/hellcat/p/9047618.html

     

    『MXNet』第四弹_Gluon自定义层

     

    一、不含参数层

    通过继承Block自定义了一个将输入减掉均值的层:CenteredLayer类,并将层的计算放在forward函数里,

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    from mxnet import nd, gluon
    from mxnet.gluon import nn
     
    class CenteredLayer(nn.Block):
        def __init__(self**kwargs):
            super(CenteredLayer, self).__init__(**kwargs)
     
        def forward(self, x):
            return - x.mean()
     
    # 直接使用这个层
    layer = CenteredLayer()
    # layer(nd.array([1, 2, 3, 4, 5]))
     
    # 构建更复杂模型
    net = nn.Sequential()
    net.add(nn.Dense(128))
    net.add(nn.Dense(10))
    net.add(CenteredLayer())
     
    # 初始化、运行……
    net.initialize()
    = net(nd.random.uniform(shape=(48)))

    二、含参数层

    注意,本节实现的自定义层不能自动推断输入尺寸,需要手动指定

    见上节『MXNet』第三弹_Gluon模型参数在自定义层的时候我们常使用Block自带的ParameterDict类添加成员变量params,如下,

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    from mxnet import gluon
    from mxnet.gluon import nn
     
    class MyDense(nn.Block):
        def __init__(self, units, in_units, **kwargs):
            super(MyDense, self).__init__(**kwargs)
            self.weight = self.params.get('weight', shape=(in_units, units))
            self.bias = self.params.get('bias', shape=(units,))       
     
        def forward(self, x):
            linear = nd.dot(x, self.weight.data()) + self.bias.data()
            return nd.relu(linear)
     
    # 实际运行
    dense = MyDense(5, in_units=10)

     如果不想使用ParameterDict类则需要一下操作

    1
    2
    3
    # self.weight = self.params.get('weight', shape=(in_units, units))
    self.weight = gluon.Parameter('weight', shape=(in_units, units))
    self.params.update({'weight':self.weight})

    否则在net.initialize()初始化时是初始化不到ParameterDict外变量的。

     有关这一点详见下面:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    def __init__(self, conv_arch, dropout_keep_prob, **kwargs):
        super(SSD, self).__init__(**kwargs)
        self.vgg_conv = nn.Sequential()
        self.vgg_conv.add(repeat(*conv_arch[0], pool=False))
        [self.vgg_conv.add(repeat(*conv_arch[i])) for in range(1len(conv_arch))]
        # 迭代器对象只能进行单次迭代,所以将之转化为tuple,否则识别参数处迭代后forward再次迭代直接跳出循环
        # self.vgg_conv = tuple([repeat(*conv_arch[i])
        #                       for i in range(len(conv_arch))])
        # 只能识别实例属性直接为mx层函数或者mx序列对象的参数,如果使用其他容器,需要将参数收集进参数字典
        # _ = [self.params.update(block.collect_params()) for block in self.vgg_conv]
     
    def forward(self, x, feat_layers):
        end_points = {'block0': x}
        for (index, block) in enumerate(self.vgg_conv):
            end_points.update({'block{:d}'.format(index+1): block(end_points['block{:d}'.format(index)])})
        return end_points

    属性对象是mxnet的对象时才能默认识别层中的参数,否则需要显式收集进self.params中。

    测试代码:

    1
    2
    3
    4
    5
    6
    7
    8
    if __name__ == '__main__':
     
        ssd = SSD(conv_arch=((264), (2128), (3256), (3512), (3512)),
                  dropout_keep_prob=0.5)
        ssd.initialize()
        = mx.ndarray.random.uniform(shape=(11304304))
        import pprint as pp
        pp.pprint([x[1].shape for in ssd(X).items()])

    自行验证即可。

  • 相关阅读:
    Java 第十一届 蓝桥杯 省模拟赛 梅花桩
    Java 第十一届 蓝桥杯 省模拟赛 梅花桩
    Java 第十一届 蓝桥杯 省模拟赛 梅花桩
    Java 第十一届 蓝桥杯 省模拟赛 元音字母辅音字母的数量
    Java 第十一届 蓝桥杯 省模拟赛 元音字母辅音字母的数量
    Java 第十一届 蓝桥杯 省模拟赛 元音字母辅音字母的数量
    Java 第十一届 蓝桥杯 省模拟赛 最大的元素距离
    Java 第十一届 蓝桥杯 省模拟赛 递增序列
    Java 第十一届 蓝桥杯 省模拟赛 递增序列
    Java 第十一届 蓝桥杯 省模拟赛 最大的元素距离
  • 原文地址:https://www.cnblogs.com/jukan/p/10795074.html
Copyright © 2011-2022 走看看