zoukankan      html  css  js  c++  java
  • 4.3 模型参数的延后初始化

    模型参数的延后初始化

    延后初始化

    在上一节使用的多层感知机net里,我们创建的隐藏层仅仅指定了输出大小为256。当调用initialize函数时,由于隐藏层输入个数依然未知,系统也无法得知该层权重参数的形状。只有在当我们将形状是(2, 20)的输入(X)传进网络做前向计算net(X)时,系统才推断出该层的权重参数形状为(256, 20)。因此,这时候我们才能真正开始初始化参数。
    使用上一节中定义的MyInit类来演示这一过程。创建多层感知机,并使用MyInit实例来初始化模型参数。

    #导包
    from mxnet import init, nd
    from mxnet.gluon import nn
    #定义上一节的初始化函数
    class MyInit(init.Initializer):
        def _init_weight(self, name, data):
            print('Init', name, data.shape)
            # 实际的初始化逻辑在此省略了
    #实例化
    net = nn.Sequential()
    #添加隐藏层,输出层
    net.add(nn.Dense(256, activation='relu'),
            nn.Dense(10))
    #初始化
    net.initialize(init=MyInit())
    

    initialize函数时并没有真正初始化参数。
    定义输入并执行一次前向计算:

    #定义输入
    X = nd.random.uniform(shape=(2, 20))
    #使用net进行正向计算
    Y = net(X)
    #输出:
    #在调用net(x)时才进行初始化参数
    Init dense0_weight (256, 20)
    Init dense1_weight (10, 256)
    

    这时候,有关模型参数的信息被打印出来。在根据输入(X)做前向计算时,系统能够根据输入的形状自动推断出所有层的权重参数的形状
    系统在创建这些参数之后,调用MyInit实例对它们进行初始化,然后才进行前向计算。这个初始化只会在第一次前向计算时被调用。之后我们再运行前向计算net(X)时则不会重新初始化
    系统将真正的参数初始化延后到获得足够信息时才执行的行为叫作延后初始化(deferred initialization),这里的足够信息是指需要知道输入的shape.它可以让模型的创建更加简单:只需要定义每个层的输出大小,而不用人工推测它们的输入个数。这对于之后将介绍的定义多达数十甚至数百层的网络来说尤其方便。
    但是在第一次前向计算之前,我们无法直接操作模型参数,例如无法使用data函数和set_data函数来获取和修改参数。因此,我们经常会额外做一次前向计算来迫使参数被真正地初始化。

    避免延后初始化

    如果系统在调用initialize函数时能够知道所有参数的形状,那么延后初始化就不会发生.
    第一种情况是我们要对已初始化的模型重新初始化时。因为参数形状不会发生变化,所以系统能够立即进行重新初始化。

    #重新初始化
    net.initialize(init=MyInit(), force_reinit=True)
    

    第二种情况是我们在创建层的时候指定了它的输入个数,使系统不需要额外的信息来推测参数形状
    通过in_units来指定每个全连接层的输入个数,使初始化能够在initialize函数被调用时立即发生。

    #实例化
    net = nn.Sequential()
    #添加隐藏层,指定输入个数20,激活函数'relu'
    net.add(nn.Dense(256, in_units=20, activation='relu'))
    #添加输出层,指定输入个数256
    net.add(nn.Dense(10, in_units=256))
    #初始化权重参数
    net.initialize(init=MyInit())
    
  • 相关阅读:
    从浏览器输入URL到页面渲染的过程
    安全分析的几个好的工具网站的使用
    从一次渗透谈到linux如何反弹shell
    python 进行抓包嗅探
    MYSQL的索引和常见函数
    一篇博客搞定redis基础
    新型横向移动工具原理分析、代码分析、优缺点以及检测方案
    Java反序列化漏洞的挖掘、攻击与防御
    关于Memcached反射型DRDoS攻击分析
    spark未授权RCE漏洞
  • 原文地址:https://www.cnblogs.com/strategist-614/p/14411658.html
Copyright © 2011-2022 走看看