  • register_buffer vs register_parameter

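    First, look at the members that nn.Module initializes in its constructor (excerpt from the PyTorch source; newer releases initialize additional fields):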

        def __init__(self):
            """
            Initializes internal Module state, shared by both nn.Module and ScriptModule.
            """
            torch._C._log_api_usage_once("python.nn_module")
    
            self.training = True
            self._parameters = OrderedDict()
            self._buffers = OrderedDict()
            self._backward_hooks = OrderedDict()
            self._forward_hooks = OrderedDict()
            self._forward_pre_hooks = OrderedDict()
            self._state_dict_hooks = OrderedDict()
            self._load_state_dict_pre_hooks = OrderedDict()
            self._modules = OrderedDict()
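
    These dicts are private attributes, but they can be inspected directly. The following is a minimal sketch using built-in modules, assuming a recent PyTorch release. BatchNorm1d fills both _parameters and _buffers, and Sequential stores its children in _modules. Read these attributes for inspection only, since they are not a public API.

```python
import torch
from torch import nn

# BatchNorm1d (affine=True, track_running_stats=True by default) fills both dicts:
# its learnable scale/shift live in _parameters, its running statistics in _buffers.
bn = nn.BatchNorm1d(4)
param_names = list(bn._parameters.keys())
buffer_names = list(bn._buffers.keys())
print(param_names)   # ['weight', 'bias']
print(buffer_names)  # ['running_mean', 'running_var', 'num_batches_tracked']

# Container modules keep their child modules in _modules, keyed by name.
seq = nn.Sequential(nn.Linear(4, 2), nn.ReLU())
child_names = list(seq._modules.keys())
print(child_names)   # ['0', '1']
```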
    

    register_buffer and register_parameter only touch _buffers and _parameters, respectively: each call writes one entry into the corresponding member.
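
    The sketch below (the module and tensor names are hypothetical) shows each call landing in its dict. It also shows that assigning an nn.Parameter as an attribute has the same effect as register_parameter, because nn.Module.__setattr__ routes it into _parameters.

```python
import torch
from torch import nn

class Demo(nn.Module):  # hypothetical example module
    def __init__(self):
        super().__init__()
        # Explicit registration: each call writes one entry into the matching dict
        self.register_parameter("w1", nn.Parameter(torch.zeros(2)))
        self.register_buffer("b1", torch.zeros(2))
        # Assigning an nn.Parameter as an attribute is routed by
        # nn.Module.__setattr__ into _parameters as well
        self.w2 = nn.Parameter(torch.ones(2))

d = Demo()
print(list(d._parameters.keys()))  # ['w1', 'w2']
print(list(d._buffers.keys()))     # ['b1']
# Registered entries remain reachable as ordinary attributes
print(d.b1)  # tensor([0., 0.])
```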

    Entries in both _buffers and _parameters are included in state_dict, and both are moved between devices by .cpu() / .cuda().
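
    The following is a minimal sketch with a hypothetical module. state_dict lists the registered parameter and buffer, and a single .to(...) call converts both. The sketch uses the GPU only if one is available. It assumes .cpu(), .cuda() and .to() share the same conversion path (Module._apply), which applies to every registered parameter and buffer.

```python
import torch
from torch import nn

class Demo(nn.Module):  # hypothetical example module
    def __init__(self):
        super().__init__()
        self.register_parameter("weight", nn.Parameter(torch.randn(2, 2)))
        self.register_buffer("running", torch.zeros(2))

m = Demo()
# state_dict() lists parameters first, then buffers
keys = list(m.state_dict().keys())
print(keys)  # ['weight', 'running']

# Move to the GPU if there is one, and convert to float64 at the same time;
# both the parameter and the buffer are converted
device = "cuda" if torch.cuda.is_available() else "cpu"
m.to(device=device, dtype=torch.float64)
print(m.weight.device, m.weight.dtype)
print(m.running.device, m.running.dtype)
```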
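    Tensors in _buffers are not updated by the optimizer: model.parameters() does not return them, so an optimizer built from model.parameters() never sees them. If a model needs tensors that must be saved through state_dict but should not be trained by the optimizer, register them as buffers.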
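
    One SGD step on a hypothetical module illustrates this. The parameter w moves by lr times its gradient (the gradient is 2 here, so 1 - 0.1 * 2 = 0.8). The buffer scale takes part in the forward pass but stays unchanged. It is also not in parameters() and does not require grad.

```python
import torch
from torch import nn

class Scaled(nn.Module):  # hypothetical example module
    def __init__(self):
        super().__init__()
        self.register_parameter("w", nn.Parameter(torch.ones(3)))
        # A fixed, non-trainable tensor that should still be saved with the model
        self.register_buffer("scale", torch.full((3,), 2.0))

    def forward(self, x):
        return x * self.w * self.scale

model = Scaled()
# Buffers are not returned by parameters(), so the optimizer never sees them
opt = torch.optim.SGD(model.parameters(), lr=0.1)
print(len(list(model.parameters())))  # 1
print(model.scale.requires_grad)      # False

loss = model(torch.ones(3)).sum()  # d(loss)/dw = x * scale = 2
opt.zero_grad()
loss.backward()
opt.step()

print(model.w)      # each element is now 1 - 0.1 * 2 = 0.8
print(model.scale)  # unchanged: tensor([2., 2., 2.])
```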
    For example, maskrcnn_benchmark uses register_buffer in its anchor_generator, and detectron2 uses it in its BatchNorm2d.
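
    The sketch below is a simplified batch-norm layer with fixed statistics, in the spirit of the detectron2 case mentioned above. The class name FrozenBN2d and its details are hypothetical and not copied from detectron2. It stores every tensor as a buffer, so the layer is saved and moved with the model but exposes no trainable parameters.

```python
import torch
from torch import nn

class FrozenBN2d(nn.Module):  # hypothetical, simplified sketch
    """Batch norm whose statistics and affine terms are all fixed buffers."""

    def __init__(self, num_features, eps=1e-5):
        super().__init__()
        self.eps = eps
        # Everything is a buffer: saved/loaded and moved with the model,
        # but invisible to parameters() and therefore to the optimizer
        self.register_buffer("weight", torch.ones(num_features))
        self.register_buffer("bias", torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))

    def forward(self, x):
        # Fold the statistics into a per-channel scale and shift (NCHW input)
        scale = self.weight * (self.running_var + self.eps).rsqrt()
        shift = self.bias - self.running_mean * scale
        return x * scale.view(1, -1, 1, 1) + shift.view(1, -1, 1, 1)

bn = FrozenBN2d(3)
x = torch.randn(2, 3, 4, 4)
y = bn(x)
print(len(list(bn.parameters())))  # 0
print(list(bn.state_dict().keys()))
# ['weight', 'bias', 'running_mean', 'running_var']
```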

    If you instead write self.param1 = torch.randn(2, 2), param1 is just a plain attribute: it is not included in state_dict and is not moved between devices by .cpu() / .cuda().
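
    A minimal sketch contrasting a registered buffer with a plain tensor attribute (names are hypothetical). It uses .double() instead of .cuda() so that it runs without a GPU. Both calls go through the same conversion path, so the plain attribute is skipped in the same way.

```python
import torch
from torch import nn

class Demo(nn.Module):  # hypothetical example module
    def __init__(self):
        super().__init__()
        self.register_buffer("registered", torch.zeros(2, 2))
        # Plain tensor attribute: nn.Module does not track it
        self.plain = torch.zeros(2, 2)

m = Demo()
print(list(m.state_dict().keys()))  # ['registered']

m.double()  # converts only the tensors nn.Module tracks
print(m.registered.dtype)  # torch.float64
print(m.plain.dtype)       # torch.float32 (left untouched)
```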

    import torch
    from torch import nn
    
    class MyModel(nn.Module):
        def __init__(self):
            super(MyModel, self).__init__()
            print('before register buffer:\n', self._buffers, end='\n\n')
            # register_buffer writes into self._buffers
            self.register_buffer('mybuffer1', torch.randn(2, 2))
            print('after register buffer:\n', self._buffers, end='\n\n')
    
            print('before register parameter:\n', self._parameters, end='\n\n')
            # register_parameter writes into self._parameters
            self.register_parameter('my_param1', nn.Parameter(torch.randn(3, 3)))
            print('after register parameter:\n', self._parameters, end='\n\n')
            # Plain attribute: not tracked by nn.Module
            self.param1 = torch.randn(3, 3)
    
        def forward(self, x):
            return x
    
    mymodel = MyModel()
    mymodel.cuda()  # requires a CUDA GPU; moves registered parameters and buffers only
    print(list(mymodel.parameters()))
    print(list(mymodel.buffers()))
    print(mymodel.param1)  # still on the CPU
    
    

    The output is:

    before register buffer:
     OrderedDict()
    
    after register buffer:
     OrderedDict([('mybuffer1', tensor([[-0.4997, -1.0214],
            [ 0.5604, -2.3252]]))])
    
    before register parameter:
     OrderedDict()
    
    after register parameter:
     OrderedDict([('my_param1', Parameter containing:
    tensor([[ 0.1465,  1.1252, -0.2854],
            [ 2.2109, -0.3919,  0.0385],
            [ 0.3347,  0.1597,  0.7505]], requires_grad=True))])
    
    [Parameter containing:
    tensor([[ 0.1465,  1.1252, -0.2854],
            [ 2.2109, -0.3919,  0.0385],
            [ 0.3347,  0.1597,  0.7505]], device='cuda:0', requires_grad=True)]
    [tensor([[-0.4997, -1.0214],
            [ 0.5604, -2.3252]], device='cuda:0')]
    tensor([[ 0.6994, -2.6078,  2.0409],
            [-0.1210,  1.0048, -1.3913],
            [-1.3752, -1.3748, -2.4478]])
    
  • Original post: https://www.cnblogs.com/zi-wang/p/11773841.html