zoukankan      html  css  js  c++  java
  • 『PyTorch』第十四弹_torch.nn.Module类属性

    nn.Module基类的构造函数:

    def __init__(self):
        self._parameters = OrderedDict()
        self._modules = OrderedDict()
        self._buffers = OrderedDict()
        self._backward_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()
        self.training = True

    其中每个属性的解释如下:

    • _parameters:字典,保存用户直接设置的parameter,self.param1 = nn.Parameter(t.randn(3, 3))会被检测到,在字典中加入一个key为'param',value为对应parameter的item。而self.submodule = nn.Linear(3, 4)中的parameter则不会存于此。
    • _modules:子module,通过self.submodel = nn.Linear(3, 4)指定的子module会保存于此。
    • _buffers:缓存。如batchnorm使用momentum机制,每次前向传播需用到上一次前向传播的结果。
    • _backward_hooks_forward_hooks:钩子技术,用来提取中间变量,类似variable的hook。
    • training:BatchNorm与Dropout层在训练阶段和测试阶段中采取的策略不同,通过判断training值来决定前向传播策略。

    上述几个属性中,_parameters_modules_buffers这三个字典中的键值,都可以通过self.key方式获得,效果等价于self._parameters['key'].

    定义一个Module,这个Module即包含自己的Parameters有包含子Module及其Parameters,

    import torch as t
    from torch import nn
    from torch.autograd import Variable as V
    
    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            # 等价与self.register_parameter('param1' ,nn.Parameter(t.randn(3, 3)))
            self.param1 = nn.Parameter(t.rand(3, 3))
            self.submodel1 = nn.Linear(3, 4) 
        def forward(self, input):
            x = self.param1.mm(input)
            x = self.submodel11(x)
            return x
    net = Net()
    

    一、_modules

    # 打印网络对象的话会输出子module结构
    print(net)

    Net(
      (submodel1): Linear(in_features=3, out_features=4)
    )

    # ._modules输出的也是子module结构,不过数据结构和上面的有所不同
    print(net.submodel1)
    print(net._modules) # 字典子类

    Linear(in_features=3, out_features=4)
    OrderedDict([('submodel1', Linear(in_features=3, out_features=4))])

    for name, submodel in net.named_modules():
        print(name, submodel)

     Net(
      (submodel1): Linear(in_features=3, out_features=4)
    )
    submodel1 Linear(in_features=3, out_features=4)

    print(list(net.named_modules())) # named_modules其实是包含了本层的module集合

    [('', Net(
      (submodel1): Linear(in_features=3, out_features=4)
    )), ('submodel1', Linear(in_features=3, out_features=4))]

    二、_parameters

    # ._parameters存储的也是这个结构
    print(net.param1)
    print(net._parameters) # 字典子类,仅仅包含直接定义的nn.Parameters参数

    Parameter containing:
     0.6135  0.8082  0.4519
     0.9052  0.5929  0.2810
     0.6825  0.4437  0.3874
    [torch.FloatTensor of size 3x3]
    
    OrderedDict([('param1', Parameter containing:
     0.6135  0.8082  0.4519
     0.9052  0.5929  0.2810
     0.6825  0.4437  0.3874
    [torch.FloatTensor of size 3x3]
    )])

    for name, param in net.named_parameters():
        print(name, param.size())

    param1 torch.Size([3, 3])
    submodel1.weight torch.Size([4, 3])
    submodel1.bias torch.Size([4])

    三、_buffers

    bn = nn.BatchNorm1d(2)
    input = V(t.rand(3, 2), requires_grad=True)
    output = bn(input)
    bn._buffers
    
    OrderedDict([('running_mean', 
                  1.00000e-02 *
                    9.1559
                    1.9914
                  [torch.FloatTensor of size 2]), ('running_var', 
                   0.9003
                   0.9019
                  [torch.FloatTensor of size 2])])

    四、training

    input = V(t.arange(0, 12).view(3, 4))
    model = nn.Dropout()
    # 在训练阶段,会有一半左右的数被随机置为0
    model(input)
    
    Variable containing:
      0   2   4   0
      8  10   0   0
      0  18   0  22
    [torch.FloatTensor of size 3x4]
    model.training  = False
    # 在测试阶段,dropout什么都不做
    model(input)
    
    Variable containing:
      0   1   2   3
      4   5   6   7
      8   9  10  11
    [torch.FloatTensor of size 3x4]

    Module.train()、Module.eval() 方法和 Module.training属性的关系

    print(net.training, net.submodel1.training)
    net.train() # 将本层及子层的training设定为True
    net.eval() # 将本层及子层的training设定为False
    net.training = True # 注意,对module的设置仅仅影响本层,子module不受影响
    net.training, net.submodel1.training
    
    True True
    (True, False)
  • 相关阅读:
    Java中如何动态创建接口的实现
    使用Spring Cloud Feign作为HTTP客户端调用远程HTTP服务
    MyBatis中的@Mapper注解及配套注解使用详解(上)
    SpringMVC-DispatcherServlet工作流程及web.xml配置
    SpringMVC-DispatcherServlet配置(Spring-servlet.xml)
    Spring MVC 配置文件dispatcher-servlet.xml 文件详解
    python3 UnicodeEncodeError: 'gbk' codec can't encode character 'xa0' in position 30: illegal multibyte sequence
    python3之日期和时间(转载)
    python3 操作sqlSever
    python 抓取alexa数据
  • 原文地址:https://www.cnblogs.com/hellcat/p/8503498.html
Copyright © 2011-2022 走看看