  • PyTorch: nn.Parameter vs nn.Module.register_parameter

    register_parameter

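    Both nn.Parameter and register_parameter end up writing the parameter into the module's _parameters dict, but register_parameter takes the parameter's name as a string argument.
    As the source shows, assigning an nn.Parameter as a Module attribute is handled by Module.__setattr__, which itself calls register_parameter to write the parameter into _parameters.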

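        # Excerpt: torch.nn.Module.__setattr__ from torch/nn/modules/module.py (PyTorch 1.x era).
        # Parameter, Module and torch are names from that module's scope.
        def __setattr__(self, name, value):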
            def remove_from(*dicts):
                for d in dicts:
                    if name in d:
                        del d[name]
    
            params = self.__dict__.get('_parameters')
            if isinstance(value, Parameter):
                if params is None:
                    raise AttributeError(
                        "cannot assign parameters before Module.__init__() call")
                remove_from(self.__dict__, self._buffers, self._modules)
                self.register_parameter(name, value)
            elif params is not None and name in params:
                if value is not None:
                    raise TypeError("cannot assign '{}' as parameter '{}' "
                                    "(torch.nn.Parameter or None expected)"
                                    .format(torch.typename(value), name))
                self.register_parameter(name, value)
            else:
                modules = self.__dict__.get('_modules')
                if isinstance(value, Module):
                    if modules is None:
                        raise AttributeError(
                            "cannot assign module before Module.__init__() call")
                    remove_from(self.__dict__, self._parameters, self._buffers)
                    modules[name] = value
                elif modules is not None and name in modules:
                    if value is not None:
                        raise TypeError("cannot assign '{}' as child module '{}' "
                                        "(torch.nn.Module or None expected)"
                                        .format(torch.typename(value), name))
                    modules[name] = value
                else:
                    buffers = self.__dict__.get('_buffers')
                    if buffers is not None and name in buffers:
                        if value is not None and not isinstance(value, torch.Tensor):
                            raise TypeError("cannot assign '{}' as buffer '{}' "
                                            "(torch.Tensor or None expected)"
                                            .format(torch.typename(value), name))
                        buffers[name] = value
                    else:
                        object.__setattr__(self, name, value)
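    One way to check the claim above is to intercept register_parameter in a subclass and record which attribute assignments reach it. The sketch below is a hypothetical test harness (the class name RecordingModule and the calls list are made up for illustration). It relies only on the branches of the __setattr__ excerpt: a Parameter value is forwarded to register_parameter, a plain tensor assigned to a new name is stored as an ordinary attribute, a plain tensor assigned to an existing parameter name raises TypeError, and None is accepted and forwarded.

```python
import torch
from torch import nn


class RecordingModule(nn.Module):
    """Hypothetical module that records every register_parameter call."""

    def __init__(self):
        super().__init__()
        self.calls = []  # a list is not a Parameter: stored via object.__setattr__
        # Attribute assignment of a Parameter: __setattr__ forwards it to register_parameter.
        self.weight = nn.Parameter(torch.zeros(2))
        # A plain tensor is neither a Parameter nor a Module, so it is NOT registered.
        self.scale = torch.ones(2)

    def register_parameter(self, name, param):
        self.calls.append(name)  # log the name, then defer to the real implementation
        super().register_parameter(name, param)


m = RecordingModule()
print(m.calls)                                     # ['weight']
print([name for name, _ in m.named_parameters()])  # ['weight']

# A plain tensor assigned to an existing parameter name hits the TypeError branch.
try:
    m.weight = torch.ones(2)
except TypeError as err:
    print("TypeError:", err)

# None is allowed for an existing parameter name and is also routed through register_parameter.
m.weight = None
print(m.calls)                   # ['weight', 'weight']
print(list(m.named_parameters()))  # []
```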
    
    import torch
    from torch import nn
    
    class MyModel(nn.Module):
        def __init__(self):
            super(MyModel, self).__init__()
            print('before register:\n', self._parameters, end='\n\n')
            # Explicit registration: the name is passed as a string.
            self.register_parameter('my_param1', nn.Parameter(torch.randn(3, 3)))
            print('after register and before nn.Parameter:\n', self._parameters, end='\n\n')
    
            # Attribute assignment: Module.__setattr__ calls register_parameter internally.
            self.my_param2 = nn.Parameter(torch.randn(2, 2))
            print('after register and nn.Parameter:\n', self._parameters, end='\n\n')
    
        def forward(self, x):
            return x
    
    mymodel = MyModel()
    
    for k, v in mymodel.named_parameters():
        print(k, v)
    

    The program prints the following (the tensors are random, so the values differ between runs):

    before register:
     OrderedDict()
    
    after register and before nn.Parameter:
     OrderedDict([('my_param1', Parameter containing:
    tensor([[-1.3542, -0.4591, -2.0968],
            [-0.4345, -0.9904, -0.9329],
            [ 1.4990, -1.7540, -0.4479]], requires_grad=True))])
    
    after register and nn.Parameter:
     OrderedDict([('my_param1', Parameter containing:
    tensor([[-1.3542, -0.4591, -2.0968],
            [-0.4345, -0.9904, -0.9329],
            [ 1.4990, -1.7540, -0.4479]], requires_grad=True)), ('my_param2', Parameter containing:
    tensor([[ 1.0205, -1.3145],
            [-1.1108,  0.4288]], requires_grad=True))])
    
    my_param1 Parameter containing:
    tensor([[-1.3542, -0.4591, -2.0968],
            [-0.4345, -0.9904, -0.9329],
            [ 1.4990, -1.7540, -0.4479]], requires_grad=True)
    my_param2 Parameter containing:
    tensor([[ 1.0205, -1.3145],
            [-1.1108,  0.4288]], requires_grad=True)
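    The practical advantage of register_parameter is that the name is an ordinary string, so it can be built at run time, for example in a loop. The sketch below uses a hypothetical module (PerHeadScale, num_heads and the head_{i} naming scheme are made up for illustration). Python's built-in setattr(self, name, param) would register the same parameters, since it goes through the __setattr__ method excerpted earlier; register_parameter additionally validates the name and rejects, for example, names containing a dot.

```python
import torch
from torch import nn


class PerHeadScale(nn.Module):
    """Hypothetical module with one learnable scale per head, named at run time."""

    def __init__(self, num_heads):
        super().__init__()
        self.num_heads = num_heads
        for i in range(num_heads):
            # The name is an ordinary string computed at run time.
            self.register_parameter(f"head_{i}", nn.Parameter(torch.ones(1)))

    def forward(self, xs):
        # Look each parameter up again by its string name.
        return [x * getattr(self, f"head_{i}") for i, x in enumerate(xs)]


model = PerHeadScale(3)
print([name for name, _ in model.named_parameters()])  # ['head_0', 'head_1', 'head_2']

out = model([torch.full((1,), 2.0) for _ in range(3)])
print([t.item() for t in out])  # [2.0, 2.0, 2.0]

# setattr with a Parameter goes through Module.__setattr__ and registers it as well.
setattr(model, "extra", nn.Parameter(torch.zeros(1)))
print("extra" in dict(model.named_parameters()))  # True

# register_parameter validates the name: dots separate submodule names in
# qualified names such as "layer.weight", so they are rejected here.
try:
    model.register_parameter("bad.name", nn.Parameter(torch.zeros(1)))
except KeyError as err:
    print("KeyError:", err)
```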
    
  • Original article: https://www.cnblogs.com/zi-wang/p/11773623.html