First, let's look at the members of nn.Module:

    def __init__(self):
        """
        Initializes internal Module state, shared by both nn.Module and ScriptModule.
        """
        torch._C._log_api_usage_once("python.nn_module")

        self.training = True
        self._parameters = OrderedDict()
        self._buffers = OrderedDict()
        self._backward_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()
        self._forward_pre_hooks = OrderedDict()
        self._state_dict_hooks = OrderedDict()
        self._load_state_dict_pre_hooks = OrderedDict()
        self._modules = OrderedDict()

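register_buffer and register_parameter only touch _buffers and _parameters: calling register_buffer writes an entry into _buffers, and calling register_parameter writes an entry into _parameters.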
The contents of both _buffers and _parameters are returned by state_dict, and both can be moved between devices with .cpu() and .cuda().
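Elements of _buffers are not updated by the optimizer. If a model needs some values that should be returned by state_dict but should not be trained by the optimizer, register them in _buffers.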
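The point about the optimizer can be checked with a minimal sketch. The `Scaler` module and its `weight` and `offset` names are hypothetical and chosen only for illustration. The sketch assumes the usual pattern of passing `model.parameters()` to the optimizer, which is why the buffer is never updated.

```python
import torch
from torch import nn


class Scaler(nn.Module):  # hypothetical module, for illustration only
    def __init__(self):
        super().__init__()
        # Trainable: listed by parameters(), so the optimizer can update it.
        self.weight = nn.Parameter(torch.ones(3))
        # Not trainable: saved in state_dict but never handed to the optimizer.
        self.register_buffer("offset", torch.full((3,), 0.5))

    def forward(self, x):
        return x * self.weight + self.offset


model = Scaler()
# Only model.parameters() is passed in, so buffers are outside the optimizer's reach.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

offset_before = model.offset.clone()
weight_before = model.weight.detach().clone()

# One training step on a toy loss.
loss = model(torch.ones(3)).sum()
loss.backward()
optimizer.step()

param_names = [name for name, _ in model.named_parameters()]
buffer_names = [name for name, _ in model.named_buffers()]
state_keys = list(model.state_dict().keys())
buffer_unchanged = torch.equal(model.offset, offset_before)
weight_changed = not torch.equal(model.weight.detach(), weight_before)

print(param_names, buffer_names, state_keys)
print(buffer_unchanged, weight_changed)
```

After the step, `weight` has moved while `offset` still holds its initial values, yet both names appear in `state_dict()`.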
For example, maskrcnn_benchmark uses register_buffer when generating anchors in its anchor_generator, and detectron2 uses it in BatchNorm2d.
If you instead define self.param1 = torch.randn(2, 2), then param1 is not returned by state_dict, and it is not moved between devices by .cpu() or .cuda().

    import torch
    from torch import nn


    class MyModel(nn.Module):
        def __init__(self):
            super(MyModel, self).__init__()
            # Registering a buffer adds an entry to self._buffers.
            print('before register buffer: ', self._buffers, end=' ')
            self.register_buffer('mybuffer1', torch.randn(2, 2))
            print('after register buffer: ', self._buffers, end=' ')

            # Registering a parameter adds an entry to self._parameters.
            print('before register parameter: ', self._parameters, end=' ')
            self.register_parameter('my_param1', nn.Parameter(torch.randn(3, 3)))
            print('after register parameter: ', self._parameters, end=' ')

            # A plain tensor attribute: neither a parameter nor a buffer.
            self.param1 = torch.randn(3, 3)

        def forward(self, x):
            return x


    mymodel = MyModel()
    mymodel.cuda()  # requires a CUDA-capable device
    print(list(mymodel.parameters()))
    print(list(mymodel.buffers()))
    print(mymodel.param1)

The output is as follows:

    before register buffer:  OrderedDict()
    after register buffer:  OrderedDict([('mybuffer1', tensor([[-0.4997, -1.0214],
            [ 0.5604, -2.3252]]))])
    before register parameter:  OrderedDict()
    after register parameter:  OrderedDict([('my_param1', Parameter containing:
    tensor([[ 0.1465,  1.1252, -0.2854],
            [ 2.2109, -0.3919,  0.0385],
            [ 0.3347,  0.1597,  0.7505]], requires_grad=True))])
    [Parameter containing:
    tensor([[ 0.1465,  1.1252, -0.2854],
            [ 2.2109, -0.3919,  0.0385],
            [ 0.3347,  0.1597,  0.7505]], device='cuda:0', requires_grad=True)]
    [tensor([[-0.4997, -1.0214],
            [ 0.5604, -2.3252]], device='cuda:0')]
    tensor([[ 0.6994, -2.6078,  2.0409],
            [-0.1210,  1.0048, -1.3913],
            [-1.3752, -1.3748, -2.4478]])
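
The example above never prints state_dict, and it needs a GPU. The following is a rough CPU-only sketch of the same checks. The `Demo` class is a hypothetical stand-in for MyModel. As an assumption made for portability, `.double()` is used in place of `.cuda()`: both are module-wide conversions applied to registered parameters and buffers, so a plain tensor attribute should be skipped in the same way.

```python
import torch
from torch import nn


class Demo(nn.Module):  # hypothetical module mirroring MyModel
    def __init__(self):
        super().__init__()
        self.register_buffer("mybuffer1", torch.randn(2, 2))
        self.register_parameter("my_param1", nn.Parameter(torch.randn(3, 3)))
        self.param1 = torch.randn(3, 3)  # plain attribute, not registered

    def forward(self, x):
        return x


demo = Demo()

# Only the registered tensors are exposed through state_dict.
state_keys = sorted(demo.state_dict().keys())

# .double() stands in for .cuda() here: it converts registered
# parameters and buffers, and leaves plain tensor attributes alone.
demo.double()
dtypes = {
    "my_param1": demo.my_param1.dtype,
    "mybuffer1": demo.mybuffer1.dtype,
    "param1": demo.param1.dtype,
}

print(state_keys)
print(dtypes)
```

Under these assumptions, `param1` is missing from the state_dict keys and keeps its original dtype, which matches the tensor without `device='cuda:0'` in the output above.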