zoukankan      html  css  js  c++  java
  • 『PyTorch』第十五弹_torch.nn.Module的属性设置&查询

    一、背景知识

    python中两个属相相关方法

    result = obj.name 会调用builtin函数getattr(obj,'name')查找对应属性,如果没有name属性则调用obj.__getattr__('name')方法,再无则报错

    obj.name = value 会调用builtin函数setattr(obj,'name',value)设置对应属性,如果设置了__setattr__('name',value)方法则优先调用此方法,而非直接将值存入__dict__并新建属性

    二、nn.Module的__setattr__()方法逻辑

    nn.Module中实现了__setattr__()方法,当再class的初始化__init__()中执行module.name=value时,会在其中判断value是否属于Parameters或者nn.Module对象,是则将之存储进入__dict__._parameters和__dict__._modules两个字典中;如果是其他对象诸如Variable、List、dict等等,则调用默认操作,将值直接存入__dict__中。

    示例

    nn.Module的新建Parameter属性,在._parameters中可以查询到,在.__dict__中没有,属于.__dict__._parameters中

    import torch as t
    import torch.nn as nn
    
    module = nn.Module()
    module.param = nn.Parameter(t.ones(2,2))
    
    print(module._parameters)
    
    """
    OrderedDict([('param', Parameter containing:
                   1  1
                   1  1
                  [torch.FloatTensor of size 2x2])])
    """
    
    print(module.__dict__)
    """
    {'_backend': <torch.nn.backends.thnn.THNNFunctionBackend at 0x7f5dbcf8c160>,
     '_backward_hooks': OrderedDict(),
     '_buffers': OrderedDict(),
     '_forward_hooks': OrderedDict(),
     '_forward_pre_hooks': OrderedDict(),
     '_modules': OrderedDict(),
     '_parameters': OrderedDict([('param', Parameter containing:
                                              1  1
                                              1  1
                                             [torch.FloatTensor of size 2x2])]),
     'training': True}
    """
    

    以通常List的格式传入的子Module直接从属于属于.__dict__,并未被_modules识别

    submodule1 = nn.Linear(2,2)
    submodule2 = nn.Linear(2,2)
    module_list = [submodule1,submodule2]
    module.submodules = module_list
    
    print('_modules:',module_list)
    # _modules: [Linear (2 -> 2), Linear (2 -> 2)]
    print('__dict__[submodules]:',module.__dict__.get('submodules'))
    # __dict__[submodules]: [Linear (2 -> 2), Linear (2 -> 2)]
    print('__dict__[submodules]:',module.__dict__['submodules'])
    # __dict__[submodules]: [Linear (2 -> 2), Linear (2 -> 2)]
    

    以ModuleList格式传入的子Module可被._modules识别,而不直接从属于.__dict__

    module_list = nn.ModuleList(module_list)
    module.submodules = module_list
    
    print(isinstance(module_list,nn.Module))
    # True
    
    print(module._modules)
    """
    OrderedDict([('submodules', ModuleList (
      (0): Linear (2 -> 2)
      (1): Linear (2 -> 2)
    ))])
    """
    print(module.__dict__.get('submodules'))
    # None
    print(module.__dict__['submodules'])
    """
    ---------------------------------------------------------------------------
    KeyError                                  Traceback (most recent call last)
    <ipython-input-19-d4344afabcbf> in <module>()
    ----> 1 print(module.__dict__['submodules'])
    
    KeyError: 'submodules'
    """
    

    三、属性查询函数__getattr__相关特性

    nn.Module的.__getattr__()方法会对__dict__._module、__dict__._parameters和__dict__._buffers这三个字典中的key进行查询。当nn.Module进行属性查询时,会先在__dict__进行查询(仅查询本级),查询不到对应属性值时,就会调用.__getattr__()方法,再无结果就报错。

    示例

    对于__dict__中的属性.training,可以看到.__getattr__('training')查询时就没有结果,

    print(module.__dict__.get('submodules'))
    # None
    
    getattr(module,'training')
    # True
    
    module.training
    # True
    
    
    module.__getattr__('training')
    """
    ---------------------------------------------------------------------------
    AttributeError                            Traceback (most recent call last)
    ……
    AttributeError: 'Module' object has no attribute 'training'
    """
    

     另外,我们可以看到.__getattr__可以查询到的结果如下,都是nn.Module自建的属性,

    module.__getattr__
    """
    <bound method Module.__getattr__ of Module (
      (submodules): ModuleList (
        (0): Linear (2 -> 2)
        (1): Linear (2 -> 2)
      )
    )>
    """
    

    对于普通的新建属性,其实和nn.Module自建的没什么不同,不同查询方式输出相似,

    module.attr1 = 2
    getattr(module,'attr1')
    # 2
    
    module.__getattr__('attr1')
    """
    ---------------------------------------------------------------------------
    AttributeError                            Traceback (most recent call last)
    ……
    AttributeError: 'Module' object has no attribute 'attr1'
    """
    

    对于nn.Module的特殊属性,可以看到,getattr和.__getattr__均可查到,这也是由于getattr一次查找无果后,调用.__getattr__的结果,

    getattr(module,'param')
    """
    Parameter containing:
     1  1
     1  1
    [torch.FloatTensor of size 2x2]
    """
    
    module.__getattr__('param')
    """
    Parameter containing:
     1  1
     1  1
    [torch.FloatTensor of size 2x2]
    """
    
  • 相关阅读:
    ECSHOP文章详情页的标题上加个链接
    点击复制代码到粘贴板代码
    ecshop商城用户名和邮箱都能登陆方法
    ECSHOP商品页发表评论时 取消EMAIL必填
    ECSHOP 模板结构说明
    ecshop文章分类页 显视当前文章分类名称及商品分类页显视当前分类名称
    ecshop商城用户名和邮箱都能登陆方法
    Ecshop品牌页如何自定义Title
    常见的颜色搭配、衣裤搭配指南
    ECSHOP首页显示积分商城里的商品
  • 原文地址:https://www.cnblogs.com/hellcat/p/8509351.html
Copyright © 2011-2022 走看看