zoukankan      html  css  js  c++  java
  • MXNET:深度学习计算-模型构建

    进入更深的层次:模型构造、参数访问、自定义层和使用 GPU。

    模型构建

    在多层感知机的实现中,我们首先构造 Sequential 实例,然后依次添加两个全连接层。其中第一层的输出大小为 256,即隐藏层单元个数是 256;第二层的输出大小为 10,即输出层单元个数是 10。

    我们之前都是用了 Sequential 类来构造模型。这里我们另外一种基于 Block 类的模型构造方法,它让构造模型更加灵活,也将让你能更好的理解 Sequential 的运行机制。

    继承 Block 类来构造模型

    Block 类是 gluon.nn 里提供的一个模型构造类,我们可以继承它来定义我们想要的模型。例如,我们在这里构造一个同前提到的相同的多层感知机。这里定义的 MLP 类重载了 Block 类的两个函数:init 和 forward.

    from mxnet import nd
    from mxnet.gluon import nn
    
    class MLP(nn.Block):
        # 声明带有模型参数的层,这里我们声明了两个全链接层。
        def __init__(self, **kwargs):
            # 调用 MLP 父类 Block 的构造函数来进行必要的初始化。这样在构造实例时还可以指定其他函数参数,例如后面将介绍的模型参数 params。
            super(MLP, self).__init__(**kwargs)
            # 隐藏层。
            self.hidden = nn.Dense(256, activation='relu')
            # 输出层。
            self.output = nn.Dense(10)
        # 定义模型的前向计算,即如何根据输出计算输出。
        def forward(self, x):
            return self.output(self.hidden(x))
    

    我们可以实例化 MLP 类得到 net

    x = nd.random.uniform(shape=(2,20))
    net = MLP()
    net.initialize()
    net(x)
    

    其中,net(x) 会调用了 MLP 继承至 Block 的 call 函数,这个函数将调用 MLP 定义的 forward 函数来完成前向计算。

    我们无需在这里定义反向传播函数,系统将通过自动求导,来自动生成 backward 函数。

    注意到我们不是将 Block 叫做层或者模型之类的名字,这是因为它是一个可以自由组建的部件。它的子类既可以一个层,例如 Gluon 提供的 Dense 类,也可以是一个模型,我们定义的 MLP 类,或者是模型的一个部分,例如我们会在之后介绍的 ResNet 的残差块。

    Sequential 类继承自 Block 类
    当模型的前向计算就是简单串行计算模型里面各个层的时候,我们可以将模型定义变得更加简单,这个就是 Sequential 类的目的,它通过 add 函数来添加 Block 子类实例,前向计算时就是将添加的实例逐一运行。下面我们实现一个跟 Sequential 类有相同功能的类,这样你可以看的更加清楚它的运行机制。

    class MySequential(nn.Block):
        def __init__(self, **kwargs):
            super(MySequential, self).__init__(**kwargs)
    
        def add(self, block):
            # block 是一个 Block 子类实例,假设它有一个独一无二的名字。我们将它保存在Block 类的成员变量 _children 里,其类型是 OrderedDict. 
            #当调用initialize 函数时,系统会自动对 _children 里面所有成员初始化。
            self._children[block.name] = block
    
        def forward(self, x):
            # OrderedDict 保证会按照插入时的顺序遍历元素。
            for block in self._children.values():
                x = block(x)
            return x
    

    使用:

    net = MySequential()
    net.add(nn.Dense(256, activation='relu'))
    net.add(nn.Dense(10))
    net.initialize()
    net(x)
    

    构造复杂的模型
    我们构造一个稍微复杂点的网络。在这个网络中,我们通过 get_constant 函数创建训练中不被迭代的参数,即常数参数。在前向计算中,除了使用创建的常数参数外,我们还使用 NDArray 的函数和 Python 的控制流,并多次调用同一层。

    
    class FancyMLP(nn.Block):
        def __init__(self, **kwargs):
            super(FancyMLP, self).__init__(**kwargs)
            # 使用 get_constant 创建的随机权重参数不会在训练中被迭代(即常数参数)。
            self.rand_weight = self.params.get_constant(
                'rand_weight', nd.random.uniform(shape=(20, 20)))
            self.dense = nn.Dense(20, activation='relu')
    
        def forward(self, x):
            x = self.dense(x)
            # 使用创建的常数参数,以及 NDArray 的 relu 和 dot 函数。
            x = nd.relu(nd.dot(x, self.rand_weight.data()) + 1)
            # 重用全连接层。等价于两个全连接层共享参数。
            x = self.dense(x)
            # 控制流,这里我们需要调用 asscalar 来返回标量进行比较。
            while x.norm().asscalar() > 1:
                x /= 2
            if x.norm().asscalar() < 0.8:
                x *= 10
            return x.sum()
    

    使用:

    net = FancyMLP()
    net.initialize()
    net(x)
    

    由于 FancyMLP 和 Sequential 都是 Block 的子类,我们可以嵌套调用他们。

    class NestMLP(nn.Block):
        def __init__(self, **kwargs):
            super(NestMLP, self).__init__(**kwargs)
            self.net = nn.Sequential()
            self.net.add(nn.Dense(64, activation='relu'),
                         nn.Dense(32, activation='relu'))
            self.dense = nn.Dense(16, activation='relu')
    
        def forward(self, x):
            return self.dense(self.net(x))
    
    net = nn.Sequential()
    net.add(NestMLP(), nn.Dense(20), FancyMLP())
    
    net.initialize()
    net(x)
    
  • 相关阅读:
    网络传输协议 UDP & TCP 详解
    OSI 七层协议
    (01day)python接口测试
    Python2和Python3的区别,以及为什么选Python3的原因
    JAVA反编译工具
    JAR反编译工具
    webdriver19-witchto方法
    webdriver实例14-Xpath定位的几种方法
    webdirver实例1--查找元素
    Qt插件开发
  • 原文地址:https://www.cnblogs.com/houkai/p/9521975.html
Copyright © 2011-2022 走看看