  • PyTorch: nn.Parameter vs nn.Module.register_parameter

    register_parameter

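    Assigning an nn.Parameter as a module attribute and calling register_parameter both write the parameter into the module's _parameters dict. The difference is that register_parameter takes the parameter name as a string argument, which is convenient when names are built at runtime.
    The source of nn.Module.__setattr__ shows that assigning an nn.Parameter attribute is itself implemented by calling register_parameter, which writes the parameter into _parameters: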

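        # Excerpt from torch.nn.Module (torch/nn/modules/module.py); details vary across PyTorch versions.
        # Note the Parameter branch: attribute assignment is delegated to self.register_parameter.
        def __setattr__(self, name, value):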
            def remove_from(*dicts):
                for d in dicts:
                    if name in d:
                        del d[name]
    
            params = self.__dict__.get('_parameters')
            if isinstance(value, Parameter):
                if params is None:
                    raise AttributeError(
                        "cannot assign parameters before Module.__init__() call")
                remove_from(self.__dict__, self._buffers, self._modules)
                self.register_parameter(name, value)
            elif params is not None and name in params:
                if value is not None:
                    raise TypeError("cannot assign '{}' as parameter '{}' "
                                    "(torch.nn.Parameter or None expected)"
                                    .format(torch.typename(value), name))
                self.register_parameter(name, value)
            else:
                modules = self.__dict__.get('_modules')
                if isinstance(value, Module):
                    if modules is None:
                        raise AttributeError(
                            "cannot assign module before Module.__init__() call")
                    remove_from(self.__dict__, self._parameters, self._buffers)
                    modules[name] = value
                elif modules is not None and name in modules:
                    if value is not None:
                        raise TypeError("cannot assign '{}' as child module '{}' "
                                        "(torch.nn.Module or None expected)"
                                        .format(torch.typename(value), name))
                    modules[name] = value
                else:
                    buffers = self.__dict__.get('_buffers')
                    if buffers is not None and name in buffers:
                        if value is not None and not isinstance(value, torch.Tensor):
                            raise TypeError("cannot assign '{}' as buffer '{}' "
                                            "(torch.Tensor or None expected)"
                                            .format(torch.typename(value), name))
                        buffers[name] = value
                    else:
                        object.__setattr__(self, name, value)
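    The routing in the excerpt can be observed directly. The sketch below uses a hypothetical subclass, LoggingModule, that overrides register_parameter only to record which names pass through it. It assumes a PyTorch version whose __setattr__ still calls self.register_parameter as in the excerpt. It also exercises the second branch: assigning a plain Tensor to an existing parameter name raises TypeError, while assigning None is allowed and is again routed through register_parameter.

```python
import torch
from torch import nn


class LoggingModule(nn.Module):
    """Hypothetical module that logs every name passed to register_parameter."""

    def __init__(self):
        super().__init__()
        self.registered = []  # plain list: stored through object.__setattr__
        # Plain attribute assignment of an nn.Parameter; __setattr__ forwards it
        # to self.register_parameter, so the override below sees it.
        self.weight = nn.Parameter(torch.randn(2, 2))

    def register_parameter(self, name, param):
        self.registered.append(name)
        super().register_parameter(name, param)


m = LoggingModule()
print(m.registered)  # ['weight']

# Name already in _parameters, value is a plain Tensor: TypeError is raised
# before register_parameter is reached, so nothing new is logged.
try:
    m.weight = torch.zeros(2, 2)
except TypeError as err:
    print("TypeError:", err)

# Assigning None to an existing parameter name is allowed and is also
# delegated to register_parameter.
m.weight = None
print(m.registered)              # ['weight', 'weight']
print(m._parameters["weight"])   # None
```

    named_parameters() skips entries whose value is None, so after the last assignment the module reports no parameters even though the "weight" key is still present in _parameters.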
    
    import torch
    from torch import nn
    
    class MyModel(nn.Module):
        def __init__(self):
            super(MyModel, self).__init__()
            print('before register:\n', self._parameters, end='\n\n')
            # Explicit registration: the name is passed as a string
            self.register_parameter('my_param1', nn.Parameter(torch.randn(3, 3)))
            print('after register and before nn.Parameter:\n', self._parameters, end='\n\n')
    
            # Plain attribute assignment: __setattr__ forwards this to register_parameter
            self.my_param2 = nn.Parameter(torch.randn(2, 2))
            print('after register and nn.Parameter:\n', self._parameters, end='\n\n')
    
        def forward(self, x):
            return x
    
    mymodel = MyModel()
    
    for k, v in mymodel.named_parameters():
        print(k, v)
    

    The program prints:

    before register:
     OrderedDict()
    
    after register and before nn.Parameter:
     OrderedDict([('my_param1', Parameter containing:
    tensor([[-1.3542, -0.4591, -2.0968],
            [-0.4345, -0.9904, -0.9329],
            [ 1.4990, -1.7540, -0.4479]], requires_grad=True))])
    
    after register and nn.Parameter:
     OrderedDict([('my_param1', Parameter containing:
    tensor([[-1.3542, -0.4591, -2.0968],
            [-0.4345, -0.9904, -0.9329],
            [ 1.4990, -1.7540, -0.4479]], requires_grad=True)), ('my_param2', Parameter containing:
    tensor([[ 1.0205, -1.3145],
            [-1.1108,  0.4288]], requires_grad=True))])
    
    my_param1 Parameter containing:
    tensor([[-1.3542, -0.4591, -2.0968],
            [-0.4345, -0.9904, -0.9329],
            [ 1.4990, -1.7540, -0.4479]], requires_grad=True)
    my_param2 Parameter containing:
    tensor([[ 1.0205, -1.3145],
            [-1.1108,  0.4288]], requires_grad=True)
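    Because register_parameter takes the name as an ordinary string, it fits cases where parameter names are generated at runtime. The sketch below uses a hypothetical module, StackedScales, with arbitrary sizes: it registers one learnable scale vector per layer in a loop and reads them back with getattr. It also shows that register_parameter rejects names containing ".", since dots are reserved for nested-module paths such as "block.weight".

```python
import torch
from torch import nn


class StackedScales(nn.Module):
    """Hypothetical module: one learnable scale vector per layer, named at runtime."""

    def __init__(self, num_layers: int, dim: int):
        super().__init__()
        self.num_layers = num_layers
        for i in range(num_layers):
            # The parameter name is an ordinary string built at runtime.
            self.register_parameter(f"scale_{i}", nn.Parameter(torch.ones(dim)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for i in range(self.num_layers):
            # Registered parameters are reachable as attributes via Module.__getattr__.
            x = x * getattr(self, f"scale_{i}")
        return x


model = StackedScales(num_layers=3, dim=4)
print([name for name, _ in model.named_parameters()])  # ['scale_0', 'scale_1', 'scale_2']
print(model(torch.ones(2, 4)).shape)                    # torch.Size([2, 4])

# Dots are reserved for nested-module paths, so this name is refused.
try:
    model.register_parameter("bad.name", nn.Parameter(torch.zeros(1)))
except KeyError as err:
    print("KeyError:", err)
```

    The same runtime naming can also be done with setattr(self, name, nn.Parameter(...)), which goes through __setattr__ and therefore register_parameter; nn.ParameterList and nn.ParameterDict are the container-based alternatives.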
    
  • Original post: https://www.cnblogs.com/zi-wang/p/11773623.html