zoukankan      html  css  js  c++  java
  • register_buffer vs register_parameter

    先来看下nn.Module的成员:

        def __init__(self):
            """
            Initializes internal Module state, shared by both nn.Module and ScriptModule.
            """
            torch._C._log_api_usage_once("python.nn_module")
    
            self.training = True
            self._parameters = OrderedDict()
            self._buffers = OrderedDict()
            self._backward_hooks = OrderedDict()
            self._forward_hooks = OrderedDict()
            self._forward_pre_hooks = OrderedDict()
            self._state_dict_hooks = OrderedDict()
            self._load_state_dict_pre_hooks = OrderedDict()
            self._modules = OrderedDict()
    

    register_bufferregister_parameter只涉及到_buffer_parameters,调用这两个函数分别会向两个成员写入数据。

    _buffer_parameter都会被state_dict返回,且可以通过.cpu().cuda()在设备间进行转换。
    _buffer中的元素不会被优化器更新,如果在模型中需要需要一些参数,并且要通过state_dict返回,且不需要被优化器训练,那么这些参数可以注册在_buffer中。
    例如在maskrcnn_benchmark中的anchor_generator生成中就用到了register_buffer,以及detectron2中的BatchNorm2d

    如果定义self.param1=torch.randn(2,2),那么param1是不会被state_dict返回的,且不会被.cpu().cuda()在设备间进行转换。

    import torch
    from torch import nn
    
    class MyModel(nn.Module):
        def __init__(self):
            super(MyModel, self).__init__()
            print('before register buffer:
    ', self._buffers, end='
    
    ')
            self.register_buffer('mybuffer1', torch.randn(2, 2))
            print('after register buffer:
    ', self._buffers, end='
    
    ')
    
            print('before register parameter:
    ', self._parameters, end='
    
    ')
            self.register_parameter('my_param1', nn.Parameter(torch.randn(3, 3)))
            print('after register parameter:
    ', self._parameters, end='
    
    ')
            self.param1 = torch.randn(3, 3)
    
        def forward(self, x):
            return x
    
    mymodel = MyModel()
    mymodel.cuda()
    print(list(mymodel.parameters()))
    print(list(mymodel.buffers()))
    print(mymodel.param1)
    
    

    返回如下

    before register buffer:
     OrderedDict()
    
    after register buffer:
     OrderedDict([('mybuffer1', tensor([[-0.4997, -1.0214],
            [ 0.5604, -2.3252]]))])
    
    before register parameter:
     OrderedDict()
    
    after register parameter:
     OrderedDict([('my_param1', Parameter containing:
    tensor([[ 0.1465,  1.1252, -0.2854],
            [ 2.2109, -0.3919,  0.0385],
            [ 0.3347,  0.1597,  0.7505]], requires_grad=True))])
    
    [Parameter containing:
    tensor([[ 0.1465,  1.1252, -0.2854],
            [ 2.2109, -0.3919,  0.0385],
            [ 0.3347,  0.1597,  0.7505]], device='cuda:0', requires_grad=True)]
    [tensor([[-0.4997, -1.0214],
            [ 0.5604, -2.3252]], device='cuda:0')]
    tensor([[ 0.6994, -2.6078,  2.0409],
            [-0.1210,  1.0048, -1.3913],
            [-1.3752, -1.3748, -2.4478]])
    
  • 相关阅读:
    使用SpringAop 验证方法参数是否合法
    log4jdbc-remix安装配置
    mybatis和spring3.1整合
    MyBatis-Spring 执行SQL语句的流程
    SSH配置log4j的方法
    Drupal 判断匿名用户必须先登录的解决方法
    Drupal 出错的解决办法
    crontab执行PHP
    本地生成Rails API文档
    一个根据身份证号获取的程序
  • 原文地址:https://www.cnblogs.com/zi-wang/p/11773841.html
Copyright © 2011-2022 走看看