zoukankan      html  css  js  c++  java
  • Pytorch学习笔记07----nn.Module类与前向传播函数forward的理解

    1.nn.Module类理解

    pytorch里面一切自定义操作基本上都是继承nn.Module类来实现的

    方法预览:

    class Module(object):
        def __init__(self):
        def forward(self, *input):
     
        def add_module(self, name, module):
        def cuda(self, device=None):
        def cpu(self):
        def __call__(self, *input, **kwargs):
        def parameters(self, recurse=True):
        def named_parameters(self, prefix='', recurse=True):
        def children(self):
        def named_children(self):
        def modules(self):  
        def named_modules(self, memo=None, prefix=''):
        def train(self, mode=True):
        def eval(self):
        def zero_grad(self):
        def __repr__(self):
        def __dir__(self):
    '''
    有一部分没有完全列出来
    '''

    我们在定义自已的网络的时候,需要继承nn.Module类,并重新实现构造函数__init__和forward这两个方法。但有一些注意技巧:

    (1)一般把网络中具有可学习参数的层(如全连接层、卷积层等)放在构造函数__init__()中,当然我也可以吧不具有参数的层也放在里面;

    (2)一般把不具有可学习参数的层(如ReLU、dropout、BatchNormanation层)可放在构造函数中,也可不放在构造函数中,如果不放在构造函数__init__里面,则在forward方法里面可以使用nn.functional来代替
        
    (3)forward方法是必须要重写的,它是实现模型的功能,实现各个层之间的连接关系的核心
    总结:

    torch.nn是专门为神经网络设计的模块化接口。nn构建于autograd之上,可以用来定义和运行神经网络。
    nn.Module是nn中十分重要的类,包含网络各层的定义及forward方法
    定义自已的网络:
      需要继承nn.Module类,并实现forward方法。
      一般把网络中具有可学习参数的层放在构造函数__init__()中,
      不具有可学习参数的层(如ReLU)可放在构造函数中,也可不放在构造函数中(而在forward中使用nn.functional来代替)
      只要在nn.Module的子类中定义了forward函数,backward函数就会被自动实现(利用Autograd)
      在forward函数中可以使用任何Variable支持的函数,毕竟在整个pytorch构建的图中,是Variable在流动。还可以使用if,for,print,log等python语法.
    注:Pytorch基于nn.Module构建的模型中,只支持mini-batch的Variable输入方式

    2.forward()函数自动调用的理解和分析

    最近在使用pytorch的时候,模型训练时,不需要使用forward,只要在实例化一个对象中传入对应的参数就可以自动调用 forward 函数

    自动调用 forward 函数原因分析:

    利用Python的语言特性,y = model(x)是调用了对象model的__call__方法,而nn.Module把__call__方法实现为类对象的forward函数,所以任意继承了nn.Module的类对象都可以这样简写来调用forward函数。

    案例:

    class LeNet(nn.Module):
        def __init__(self):
            super(LeNet, self).__init__()
     
        layer1 = nn.Sequential()
        layer1.add_module('conv1', nn.Conv(1, 6, 3, padding=1))
        layer1.add_moudle('pool1', nn.MaxPool2d(2, 2))
        self.layer1 = layer1
     
        layer2 = nn.Sequential()
        layer2.add_module('conv2', nn.Conv(6, 16, 5))
        layer2.add_moudle('pool2', nn.MaxPool2d(2, 2))
        self.layer2 = layer2
     
        layer3 = nn.Sequential()
        layer3.add_module('fc1', nn.Linear(400, 120))
        layer3.add_moudle('fc2', nn.Linear(120, 84))
        layer3.add_moudle('fc3', nn.Linear(84, 10))
        self.layer3 = layer3
        

      def forward(self, x): x = self.layer1(x) x = self.layer2(x) x = x.view(x.size(0), -1) x = self.layer3(x) return x

    模型调用:

    model = LeNet()
    y = model(x)

    调用forward方法的具体流程是:

    执行y = model(x)时,由于LeNet类继承了Module类,而Module这个基类中定义了__call__方法,所以会执行__call__方法,而__call__方法中调用了forward()方法

    只要定义类型的时候,实现__call__函数,这个类型就成为可调用的。 换句话说,我们可以把这个类型的对象当作函数来使用

    定义__call__方法的类可以当作函数调用(参见:https://www.cnblogs.com/luckyplj/p/13378008.html)

        def __call__(self, *input, **kwargs):
            for hook in self._forward_pre_hooks.values():
                result = hook(self, input)
                if result is not None:
                    if not isinstance(result, tuple):
                        result = (result,)
                    input = result
            if torch._C._get_tracing_state():
                result = self._slow_forward(*input, **kwargs)
            else:
                result = self.forward(*input, **kwargs)
            for hook in self._forward_hooks.values():
                hook_result = hook(self, input, result)
                if hook_result is not None:
                    result = hook_result
            if len(self._backward_hooks) > 0:
                var = result
                while not isinstance(var, torch.Tensor):
                    if isinstance(var, dict):
                        var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
                    else:
                        var = var[0]
                grad_fn = var.grad_fn
                if grad_fn is not None:
                    for hook in self._backward_hooks.values():
                        wrapper = functools.partial(hook, self)
                        functools.update_wrapper(wrapper, hook)
                        grad_fn.register_hook(wrapper)
            return result

    总结:当执行model(x)的时候,底层自动调用forward方法计算结果

    参考文献:
    https://blog.csdn.net/u011501388/article/details/84062483

  • 相关阅读:
    Win7专业版系统下事件绑定的Command事件不执行
    Win8系统下报错:无法将字符串“*”转换为Length.
    C#创建job计划用于调用存储过程刷新数据
    with语句
    for in语句与for in语句输入顺序问题
    HighCharts日期及数值格式化
    在web.config文件中,增加“type="APP.Modules.CommandModule,CommandModules"”节点会导致awesome font字体图标显示为方框框
    String、StringBuffer与StringBuilder之间区别
    java使用maven创建springmvc web项目
    手机APP下单支付序列图
  • 原文地址:https://www.cnblogs.com/luckyplj/p/13378293.html
Copyright © 2011-2022 走看看