zoukankan      html  css  js  c++  java
  • Pytorch中Module,Parameter和Buffer的区别

    下文都将torch.nn简写成nn

    • Module: 就是我们常用的torch.nn.Module类,你定义的所有网络结构都必须继承这个类。
    • Buffer: buffer和parameter相对,就是指那些不需要参与反向传播的参数
      示例如下:
    class MyModel(nn.Module):
    	def __init__(self):
    		super(MyModel, self).__init__()
    		self.my_tensor = torch.randn(1) # 参数直接作为模型类成员变量
    		self.register_buffer('my_buffer', torch.randn(1)) # 参数注册为 buffer
    		self.my_param = nn.Parameter(torch.randn(1))
    	def forward(self, x):
    		return x	
    
    model = MyModel()
    print(model.state_dict())
    >>>OrderedDict([('my_param', tensor([1.2357])), ('my_buffer', tensor([-0.9982]))])
    
    • Parameter: 是nn.parameter.Paramter,也就是组成Module的参数。例如一个nn.Linear通常由weightbias参数组成。它的特点是默认requires_grad=True,也就是说训练过程中需要反向传播的,就需要使用这个
    import torch.nn as nn
    fc = nn.Linear(2,2)
    
    # 读取参数的方式一
    fc._parameters
    >>> OrderedDict([('weight', Parameter containing:
                  tensor([[0.4142, 0.0424],
                          [0.3940, 0.0796]], requires_grad=True)),
                 ('bias', Parameter containing:
                  tensor([-0.2885,  0.5825], requires_grad=True))])
    			  
    # 读取参数的方式二(推荐这种)
    for n, p in fc.named_parameters():
    	print(n,p)
    >>>weight Parameter containing:
    tensor([[0.4142, 0.0424],
            [0.3940, 0.0796]], requires_grad=True)
    bias Parameter containing:
    tensor([-0.2885,  0.5825], requires_grad=True)
    
    # 读取参数的方式三
    for p in fc.parameters():
    	print(p)
    >>>Parameter containing:
    tensor([[0.4142, 0.0424],
            [0.3940, 0.0796]], requires_grad=True)
    Parameter containing:
    tensor([-0.2885,  0.5825], requires_grad=True)
    

    通过上面的例子可以看到,nn.parameter.Paramterrequires_grad属性值默认为True。另外上面例子给出了三种读取parameter的方法,推荐使用后面两种(这两种的区别可参阅Pytorch: parameters(),children(),modules(),named_*区别),因为是以迭代生成器的方式来读取,第一种方式是一股脑的把参数全丢给你,要是模型很大,估计你的电脑会吃不消。

    另外需要介绍的是_parametersnn.Module__init__()函数中就定义了的一个OrderDict类,这个可以通过看下面给出的部分源码看到,可以看到还初始化了很多其他东西,其实原理都大同小异,你理解了这个之后,其他的也是同样的道理。

    class Module(object):
    	...
        def __init__(self):
            self._backend = thnn_backend
            self._parameters = OrderedDict()
            self._buffers = OrderedDict()
            self._backward_hooks = OrderedDict()
            self._forward_hooks = OrderedDict()
            self._forward_pre_hooks = OrderedDict()
            self._state_dict_hooks = OrderedDict()
            self._load_state_dict_pre_hooks = OrderedDict()
            self._modules = OrderedDict()
            self.training = True
    

    每当我们给一个成员变量定义一个nn.parameter.Paramter的时候,都会自动注册到_parameters,具体的步骤如下:

    import torch.nn as nn
    class MyModel(nn.Module):
    	def __init__(self):
    		super(MyModel, self).__init__()
    		# 下面两种定义方式均可
    		self.p1 = nn.paramter.Paramter(torch.tensor(1.0))
    		print(self._parameters)
    		self.p2 = nn.Paramter(torch.tensor(2.0))
    		print(self._parameters)
    
    • 首先运行super(MyModel, self).__init__(),这样MyModel就初始化了_paramters等一系列的OrderDict,此时所有变量还都是空的。
    • self.p1 = nn.paramter.Paramter(torch.tensor(1.0)): 这行代码会触发nn.Module预定义好的__setattr__函数,该函数部分源码如下,:
    def __setattr__(self, name, value):
    	...
    	params = self.__dict__.get('_parameters')
    	if isinstance(value, Parameter):
    		if params is None:
    			raise AttributeError(
    				"cannot assign parameters before Module.__init__() call")
    		remove_from(self.__dict__, self._buffers, self._modules)
    		self.register_parameter(name, value)
    	...
    

    __setattr__函数作用简单理解就是判断你定义的参数是否正确,如果正确就继续调用register_parameter函数进行注册,这个函数简单概括就是做了下面这件事

    def register_parameter(self,name,param):
    	...
    	self._parameters[name]=param
    

    下面我们实例化这个模型看结果怎样

    model = MyModel()
    >>>OrderedDict([('p1', Parameter containing:
    tensor(1., requires_grad=True))])
    OrderedDict([('p1', Parameter containing:
    tensor(1., requires_grad=True)), ('p2', Parameter containing:
    tensor(2., requires_grad=True))])
    

    结果和上面分析的一致。



    MARSGGBO原创


    如有意合作,欢迎私戳

    邮箱:marsggbo@foxmail.com


    2019-12-20 21:11:02



  • 相关阅读:
    新概念英语(1-25)Mrs. Smith's Kitchen
    新概念英语(1-23)Which glasses?
    新概念英语(1-21)Whick book
    BZOJ2212: [Poi2011]Tree Rotations(线段树合并)
    BZOJ4773: 负环(倍增Floyd)
    洛谷P1155 双栈排序(贪心)
    洛谷P1024 一元三次方程求解(数学)
    洛谷P1072 Hankson 的趣味题(数学)
    2018.10.26NOIP模拟赛解题报告
    洛谷P2831 愤怒的小鸟(状压dp)
  • 原文地址:https://www.cnblogs.com/marsggbo/p/12075244.html
Copyright © 2011-2022 走看看