zoukankan      html  css  js  c++  java
  • 『MXNet』第四弹_Gluon自定义层

    https://www.cnblogs.com/hellcat/p/9047618.html

     

    『MXNet』第四弹_Gluon自定义层

     

    一、不含参数层

    通过继承Block自定义了一个将输入减掉均值的层:CenteredLayer类,并将层的计算放在forward函数里,

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    from mxnet import nd, gluon
    from mxnet.gluon import nn
     
    class CenteredLayer(nn.Block):
        def __init__(self**kwargs):
            super(CenteredLayer, self).__init__(**kwargs)
     
        def forward(self, x):
            return - x.mean()
     
    # 直接使用这个层
    layer = CenteredLayer()
    # layer(nd.array([1, 2, 3, 4, 5]))
     
    # 构建更复杂模型
    net = nn.Sequential()
    net.add(nn.Dense(128))
    net.add(nn.Dense(10))
    net.add(CenteredLayer())
     
    # 初始化、运行……
    net.initialize()
    = net(nd.random.uniform(shape=(48)))

    二、含参数层

    注意,本节实现的自定义层不能自动推断输入尺寸,需要手动指定

    见上节『MXNet』第三弹_Gluon模型参数在自定义层的时候我们常使用Block自带的ParameterDict类添加成员变量params,如下,

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    from mxnet import gluon
    from mxnet.gluon import nn
     
    class MyDense(nn.Block):
        def __init__(self, units, in_units, **kwargs):
            super(MyDense, self).__init__(**kwargs)
            self.weight = self.params.get('weight', shape=(in_units, units))
            self.bias = self.params.get('bias', shape=(units,))       
     
        def forward(self, x):
            linear = nd.dot(x, self.weight.data()) + self.bias.data()
            return nd.relu(linear)
     
    # 实际运行
    dense = MyDense(5, in_units=10)

     如果不想使用ParameterDict类则需要一下操作

    1
    2
    3
    # self.weight = self.params.get('weight', shape=(in_units, units))
    self.weight = gluon.Parameter('weight', shape=(in_units, units))
    self.params.update({'weight':self.weight})

    否则在net.initialize()初始化时是初始化不到ParameterDict外变量的。

     有关这一点详见下面:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    def __init__(self, conv_arch, dropout_keep_prob, **kwargs):
        super(SSD, self).__init__(**kwargs)
        self.vgg_conv = nn.Sequential()
        self.vgg_conv.add(repeat(*conv_arch[0], pool=False))
        [self.vgg_conv.add(repeat(*conv_arch[i])) for in range(1len(conv_arch))]
        # 迭代器对象只能进行单次迭代,所以将之转化为tuple,否则识别参数处迭代后forward再次迭代直接跳出循环
        # self.vgg_conv = tuple([repeat(*conv_arch[i])
        #                       for i in range(len(conv_arch))])
        # 只能识别实例属性直接为mx层函数或者mx序列对象的参数,如果使用其他容器,需要将参数收集进参数字典
        # _ = [self.params.update(block.collect_params()) for block in self.vgg_conv]
     
    def forward(self, x, feat_layers):
        end_points = {'block0': x}
        for (index, block) in enumerate(self.vgg_conv):
            end_points.update({'block{:d}'.format(index+1): block(end_points['block{:d}'.format(index)])})
        return end_points

    属性对象是mxnet的对象时才能默认识别层中的参数,否则需要显式收集进self.params中。

    测试代码:

    1
    2
    3
    4
    5
    6
    7
    8
    if __name__ == '__main__':
     
        ssd = SSD(conv_arch=((264), (2128), (3256), (3512), (3512)),
                  dropout_keep_prob=0.5)
        ssd.initialize()
        = mx.ndarray.random.uniform(shape=(11304304))
        import pprint as pp
        pp.pprint([x[1].shape for in ssd(X).items()])

    自行验证即可。

  • 相关阅读:
    drf 之 JWT认证 什么是集群以及分布式 什么是正向代理,什么是反向代理
    drf 之自定制过滤器 分页器(三种)如何使用(重点) 全局异常 封装Response对象 自动生成接口文档
    课堂练习之“寻找最长单词链”
    《人月神话》读书笔记(三)
    用户体验
    第十四周进度报告
    课堂练习之“寻找水王”
    《人月神话》读书笔记(二)
    第二阶段冲刺(十)
    第二阶段冲刺(九)
  • 原文地址:https://www.cnblogs.com/jukan/p/10795074.html
Copyright © 2011-2022 走看看