zoukankan      html  css  js  c++  java
  • PyTorch 模型构造

    记录几种模型构造的方法:

    继承Module类来构造模型

    Module是所有神经网络模块的基类,通过继承它来得到我们需要的模型,通常我们需要重载Module类的__init__函数和forward函数。

    实例

    import torch.nn as nn
    import torch.nn.functional as F
    
    class Model(nn.Module):
        def __init__(self):
            super(Model, self).__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)
    
        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))
    

    利用Module的子类

    在Pytorch中实现了继承自Module的可以方便构造模型的类,有SequentialModuleListModuleDict

    • 使用Sequential

      当模型的前向计算为简单串联各个层的计算时,Sequential类可以通过更加简单的方式定义模型。这正是Sequential类的目的:它可以接收一个子模块的有序字典(OrderedDict)或者一系列子模块作为参数来逐一添加Module的实例,而模型的前向计算就是将这些实例按添加的顺序逐一计算。

      这里实现一个与Sequential具有相似功能的MySequential

      class MySequential(nn.Module):
      from collections import OrderedDict
      def __init__(self, *args):
          super(MySequential, self).__init__()
          if len(args) == 1 and isinstance(args[0], OrderedDict): # 如果传入的是一个OrderedDict
              for key, module in args[0].items():
                  self.add_module(key, module)  # add_module方法会将module添加进self._modules(一个OrderedDict)
          else:  # 传入的是一些Module
              for idx, module in enumerate(args):
                  self.add_module(str(idx), module)
      def forward(self, input):
          # self._modules返回一个 OrderedDict,保证会按照成员添加时的顺序遍历成员
          for module in self._modules.values():
              input = module(input)
          return input
      
      
    • 使用ModuleList

      将子模块放在一个列表(list)之中
      ModuleList可以像常规的Python list一样执行append()extend()操作,有一些区别在于ModuleList中的所有模块的参数会被自动地添加到整个网络之中

      实例

      net = nn.ModuleList([nn.Linear(784, 256), nn.ReLU()])
      net.append(nn.Linear(256, 10)) # # 类似List的append操作
      print(net[-1])  # 类似List的索引访问
      print(net)
      

      虽然SequentialModuleList都可以列表化构造网络,但二者存在区别:ModuleList仅仅是一个储存各种模块的列表,这些模块之间没有联系也没有顺序(所以不用保证相邻层的输入输出维度 匹配),而且没有实现forward功能(需要自己实现)。Sequential内的模块需要按照顺序排列,要保证相邻层的输入输出大小相匹配,内部forward功能已经实现。

      ModuleList的出现可以让网络定义前向传播时更加灵活:

      class MyModule(nn.Module):
      def __init__(self):
          super(MyModule, self).__init__()
          self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])
      
      def forward(self, x):
          # ModuleList can act as an iterable, or be indexed using ints
          for i, l in enumerate(self.linears):
              x = self.linears[i // 2](x) + l(x)
          return x
      
    • 使用ModuleDict
      ModuleDict接收一个子模块的字典作为输入, 然后也可以类似字典那样进行添加访问操作:

      net = nn.ModuleDict({
      'linear': nn.Linear(784, 256),
      'act': nn.ReLU(),
      })
      net['output'] = nn.Linear(256, 10) # 添加
      print(net['linear']) # 访问
      print(net.output)
      print(net)
      # net(torch.zeros(1, 784)) # 会报NotImplementedError
      

      ModuleList一样,使用ModuleDict时同样需要自己定义forward

    参考:

    1. 动手学深度学习PyTorch版
    2. PyTorch官方文档
  • 相关阅读:
    python 利用条件运算符:学习成绩>=90分用A表示,60-89分之间的用B表示,60分以下的用C表示。
    【原创】jmeter解决接口参数MD5加密的问题
    【原创】python+selenium+ddt+unittest实现web功能自动化测试
    【原创】基于RBI的性能测试理念,通过jmeter快速定位接口最大并发用户数
    【原创】基于pyautogui库进行自动化测试
    【原创】面向对象版本地CPU资源占用监控脚本
    【原创】相对完整的一套以Jmeter作为工具的性能测试教程(接口性能测试,数据库性能测试以及服务器端性能监测)
    【部分原创】python实现视频内的face swap(换脸)
    【原创】python基于大数据现实双色球预测
    【原创】python爬虫获取网站数据并存入本地数据库
  • 原文地址:https://www.cnblogs.com/patrolli/p/11896776.html
Copyright © 2011-2022 走看看