https://github.com/pytorch/pytorch/issues/42885
import torch
import torch.nn as nn


class Foo(nn.Conv1d):
    """A ``nn.Conv1d`` subclass whose forward pass is delegated to the parent."""

    def forward(self, x):
        """Run the parent class's ``forward`` on ``x`` and return its result."""
        output = super().forward(x)
        return output
这里return super().forward(x)怎么理解?
调用父类(此处为nn.Conv1d)的forward()方法,并返回它的计算结果。
参考:https://stackoverflow.com/questions/54752983/calling-supers-forward-method
import torch


class Parent(torch.nn.Module):
    """Base module: adds 1 to its input."""

    def forward(self, tensor):
        return tensor + 1


class Child(Parent):
    """Adds another 1 on top of whatever ``Parent.forward`` returns."""

    def forward(self, tensor):
        parent_output = super().forward(tensor)
        return parent_output + 1


def _add_one_hook(mod, inputs, output):
    """Forward hook that replaces the module's output with ``output + 1``."""
    return output + 1


module = Child()
# The hook adds 1 more, so calling the module should give `4`.
module.register_forward_hook(_add_one_hook)
# Calling the module itself runs the hook: result is 4.
print(module(torch.tensor(1)))
# Calling .forward() directly skips the hook: result stays 3.
print(module.forward(torch.tensor(1)))
def increment_by_one(module, input, output):
    """Forward hook: return the module's output incremented by 1."""
    return output + 1


class Parent(torch.nn.Module):
    """Base module: adds 1 to its input."""

    def forward(self, tensor):
        return tensor + 1


class Child(Parent):
    """Adds 1 on top of ``Parent.forward`` and carries one built-in hook."""

    def __init__(self):
        super().__init__()
        # Increment by `1` from Parent.
        # Registered once here rather than inside forward(): a hook registered
        # through super() lands on this same instance, so registering it on
        # every forward call would add one more hook (and one more +1) per call.
        super().register_forward_hook(increment_by_one)

    def forward(self, tensor):
        return super().forward(tensor) + 1


module = Child()
# Increment output by 1 so we should get `5` in total
module.register_forward_hook(increment_by_one)
print(module(torch.tensor(1)))  # and it is 5 indeed
print(module.forward(torch.tensor(1)))  # here is 3
例如DenseNet中出现类似:
定义DenseLayer(它继承自nn.Sequential,在__init__中只是通过add_module依次注册各网络层;forward则直接返回super().forward(x),即由nn.Sequential按注册顺序依次执行这些层,效果与不重写forward相同)
class DenseLayer(nn.Sequential):
    """Sequential stack: BatchNorm1d -> ReLU -> Conv1d -> Dropout1d.

    Args:
        in_channels: number of channels fed to the batch norm and the conv.
        growth_rate: number of channels produced by the conv.
    """

    def __init__(self, in_channels, growth_rate):
        super().__init__()
        # Registration order is execution order for nn.Sequential.
        named_stages = (
            ('norm', nn.BatchNorm1d(in_channels)),
            ('relu', nn.ReLU(inplace=True)),
            (
                'conv',
                nn.Conv1d(
                    in_channels,
                    growth_rate,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=False,
                ),
            ),
            ('drop', nn.Dropout1d(p=0.2)),
        )
        for stage_name, stage in named_stages:
            self.add_module(stage_name, stage)

    def forward(self, x):
        """Run ``x`` through the registered stages via ``nn.Sequential.forward``."""
        result = super().forward(x)
        return result
通过DenseLayer组装DenseBlock:
class DenseBlock(nn.Module):
    """Chain of DenseLayer modules with channel-wise concatenation.

    Layer ``i`` is built for ``in_channels + i * growth_rate`` input channels,
    matching the width of the running tensor after ``i`` concatenations.

    Args:
        in_channels: channel count of the block's input.
        growth_rate: channels each layer is constructed to produce.
        n_layers: number of DenseLayer modules to create.
    """

    def __init__(self, in_channels, growth_rate, n_layers):
        super().__init__()
        self.layers = nn.ModuleList()
        for index in range(n_layers):
            layer_in_channels = in_channels + index * growth_rate
            self.layers.append(DenseLayer(layer_in_channels, growth_rate))

    def forward(self, x):
        """Feed the running tensor to each layer and append its output on dim 1."""
        features = x
        for dense_layer in self.layers:
            new_features = dense_layer(features)
            # dim 1 is the channel axis
            features = torch.cat([features, new_features], 1)
        return features