  • Returning super().forward() in PyTorch

     https://github.com/pytorch/pytorch/issues/42885

    import torch
    import torch.nn as nn
    
    
    class Foo(nn.Conv1d):
      def forward(self, x):
        return super().forward(x) 

    How should return super().forward(x) be read here?

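    It calls the forward() method defined in the parent class (here nn.Conv1d) on the same object and returns that call's result. Foo therefore behaves exactly like nn.Conv1d. Overriding forward() this way only matters if extra code is placed around the super() call.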
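
To see the mechanics without PyTorch, here is a minimal plain-Python sketch. The classes Parent, Child and ExplicitChild are illustrative only. super().forward(x) looks up forward on the next class in the method resolution order (MRO) and calls it on the same instance. In Python 3 this is equivalent to the explicit super(Child, self).forward(x) form that appears in the Stack Overflow example.

```python
class Parent:
    def forward(self, x):
        self.seen_by_parent = self  # record which object Parent.forward ran on
        return x + 1


class Child(Parent):
    def forward(self, x):
        # Zero-argument super(): finds forward() on Parent, bound to this same self
        return super().forward(x) * 10


class ExplicitChild(Parent):
    def forward(self, x):
        # Older explicit form; identical lookup in Python 3
        return super(ExplicitChild, self).forward(x) * 10


c = Child()
print(c.forward(1))                # 20: Parent gives 2, Child multiplies by 10
print(c.seen_by_parent is c)       # True: no separate "parent object" exists
print(ExplicitChild().forward(1))  # 20
print(Child.__mro__[1] is Parent)  # True
```

The second print is the key point for what follows: super() changes which class's method runs, not which object it runs on.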

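    Reference: https://stackoverflow.com/questions/54752983/calling-supers-forward-method. The following example, adapted from that answer, also shows a subtlety: forward hooks run when the module is called as module(x), but not when forward() is called directly.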

    import torch
    
    
    class Parent(torch.nn.Module):
        def forward(self, tensor):
            return tensor + 1
    
    
    class Child(Parent):
        def forward(self, tensor):
            return super(Child, self).forward(tensor) + 1
    
    
    module = Child()
    # Increment output by 1 so we should get `4`
    module.register_forward_hook(lambda module, input, output: output + 1)
    print(module(torch.tensor(1)))  # tensor(4): module(...) runs the hook
    print(module.forward(torch.tensor(1)))  # tensor(3): forward() bypasses hooks
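
Why do the two calls differ? In PyTorch, forward hooks are run by nn.Module.__call__, which module(x) goes through, and not by forward() itself. The sketch below is a deliberately simplified, hypothetical model of that behavior in plain Python. TinyModule is not a PyTorch class, and the real __call__ does much more. It reproduces the 4 vs 3 result.

```python
class TinyModule:
    """Hypothetical stand-in for nn.Module: hooks run in __call__, not in forward()."""

    def __init__(self):
        self._forward_hooks = []

    def register_forward_hook(self, hook):
        self._forward_hooks.append(hook)

    def __call__(self, x):
        output = self.forward(x)
        # Each hook sees (module, inputs, output); a non-None return replaces output
        for hook in list(self._forward_hooks):
            result = hook(self, (x,), output)
            if result is not None:
                output = result
        return output


class Parent(TinyModule):
    def forward(self, tensor):
        return tensor + 1


class Child(Parent):
    def forward(self, tensor):
        return super().forward(tensor) + 1


module = Child()
module.register_forward_hook(lambda module, input, output: output + 1)
print(module(1))          # 4: forward gives 3, the hook adds 1
print(module.forward(1))  # 3: calling forward() directly skips the hook
```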
    

      

    def increment_by_one(module, input, output):
        return output + 1
    
    
    class Parent(torch.nn.Module):
        def forward(self, tensor):
            return tensor + 1
    
    
    class Child(Parent):
        def forward(self, tensor):
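            # super() still binds to self, so this registers the hook on this same
            # module, not on a separate Parent; a new hook is added on every call.
            super().register_forward_hook(increment_by_one)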
            return super().forward(tensor) + 1
    
    
    module = Child()
    # One hook registered here plus one registered inside forward(): 3 + 1 + 1 = 5
    module.register_forward_hook(increment_by_one)
    print(module(torch.tensor(1)))  # tensor(5) on this first call
    print(module.forward(torch.tensor(1)))  # tensor(3): forward() bypasses hooks
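
The total of 5 does not come from a separate Parent hook. super() binds to the same object, so super().register_forward_hook(...) behaves like self.register_forward_hook(...), and it runs on every forward() call. As a result, hooks pile up.

The sketch below uses a simplified, hypothetical TinyModule (not PyTorch's real class). Like nn.Module, its __call__ runs the hooks after forward() returns, so a hook added during forward() already fires on that call. It shows the result growing with each call:

```python
def increment_by_one(module, input, output):
    return output + 1


class TinyModule:
    """Hypothetical stand-in for nn.Module: hooks run in __call__, after forward()."""

    def __init__(self):
        self._forward_hooks = []

    def register_forward_hook(self, hook):
        self._forward_hooks.append(hook)

    def __call__(self, x):
        output = self.forward(x)
        # Snapshot taken after forward(), so hooks added inside forward() also run
        for hook in list(self._forward_hooks):
            result = hook(self, (x,), output)
            if result is not None:
                output = result
        return output


class Parent(TinyModule):
    def forward(self, tensor):
        return tensor + 1


class Child(Parent):
    def forward(self, tensor):
        # Same object as self: this adds one more hook on every call
        super().register_forward_hook(increment_by_one)
        return super().forward(tensor) + 1


module = Child()
module.register_forward_hook(increment_by_one)
print(module(1), len(module._forward_hooks))          # 5 2
print(module.forward(1), len(module._forward_hooks))  # 3 3
print(module(1), len(module._forward_hooks))          # 7 4
```

If the extra increment is actually wanted, one option is to register it once (for example in __init__) so the result stays stable across calls.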
    

      

    For example, DenseNet implementations use a similar pattern:

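    First define DenseLayer. It only registers its sublayers, and its forward() simply returns super().forward(x). That is nn.Sequential's forward, which passes the input through the registered submodules in order (norm, relu, conv, drop).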

    import torch
    import torch.nn as nn
    
    
    class DenseLayer(nn.Sequential):
        def __init__(self, in_channels, growth_rate):
            super().__init__()
            self.add_module('norm', nn.BatchNorm1d(in_channels))
            self.add_module('relu', nn.ReLU(inplace=True))
            self.add_module('conv', nn.Conv1d(in_channels, growth_rate, kernel_size=3,
                                               stride=1, padding=1, bias=False))
            self.add_module('drop', nn.Dropout1d(p=0.2))
    
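        def forward(self, x):
            # nn.Sequential.forward applies norm -> relu -> conv -> drop in order;
            # this override is equivalent to not defining forward() at all.
            return super().forward(x)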
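
nn.Sequential's own forward() feeds the input through its submodules in registration order, so the DenseLayer override above only delegates to it. Below is a minimal plain-Python sketch of that idea. TinySequential and PassThrough are hypothetical names, and plain functions stand in for layers.

```python
class TinySequential:
    """Hypothetical model of nn.Sequential: forward() applies each stage in order."""

    def __init__(self, *stages):
        self.stages = list(stages)

    def forward(self, x):
        for stage in self.stages:
            x = stage(x)
        return x

    def __call__(self, x):
        return self.forward(x)


class PassThrough(TinySequential):
    def forward(self, x):
        # Same pattern as DenseLayer: delegate straight to the parent's loop
        return super().forward(x)


# Scale, a ReLU-like clamp, then a shift, standing in for real layers
stages = (lambda x: x * 2, lambda x: max(x, 0), lambda x: x + 3)
print(TinySequential(*stages)(-4), PassThrough(*stages)(-4))  # 3 3
print(TinySequential(*stages)(5), PassThrough(*stages)(5))    # 13 13
```

The override adds nothing here. It is only a convenient place to insert extra logic later, such as a residual connection.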
    

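    Then build a DenseBlock from DenseLayers. Each layer receives the block input concatenated with the outputs of all previous layers along the channel axis, so layer i has in_channels + i*growth_rate input channels: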

    class DenseBlock(nn.Module):
        def __init__(self, in_channels, growth_rate, n_layers):
            super().__init__()
            self.layers = nn.ModuleList([
                DenseLayer(in_channels + i * growth_rate, growth_rate)
                for i in range(n_layers)
            ])
    
        def forward(self, x):
            for layer in self.layers:
                out = layer(x)
                x = torch.cat([x, out], 1)  # 1 = channel axis
    
            return x
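
Each layer's output has growth_rate channels and is concatenated onto its input, so the channel count grows linearly. Layer i sees in_channels + i * growth_rate channels, which matches the DenseLayer(in_channels + i*growth_rate, growth_rate) construction. The block as a whole outputs in_channels + n_layers * growth_rate channels.

Below is a small sketch of that bookkeeping. dense_block_channels is a hypothetical helper, and the sizes 64, 32 and 4 are arbitrary.

```python
def dense_block_channels(in_channels, growth_rate, n_layers):
    """Return (per-layer input channels, block output channels) for a DenseBlock."""
    layer_inputs = []
    channels = in_channels
    for _ in range(n_layers):
        layer_inputs.append(channels)  # what DenseLayer i receives
        channels += growth_rate        # torch.cat appends growth_rate channels
    return layer_inputs, channels


inputs, out_channels = dense_block_channels(64, 32, 4)
print(inputs)        # [64, 96, 128, 160]
print(out_channels)  # 192
```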
    

      

    Go and become who you want to be!
  • Original article: https://www.cnblogs.com/jiangkejie/p/14347153.html