zoukankan      html  css  js  c++  java
  • 【501】pytorch教程之nn.Module类详解

    参考:pytorch教程之nn.Module类详解——使用Module类来自定义模型

      pytorch中对于一般的序列模型,直接使用torch.nn.Sequential类及可以实现,这点类似于keras,但是更多的时候面对复杂的模型,比如:多输入多输出、多分支模型、跨层连接模型、带有自定义层的模型等,就需要自己来定义一个模型了。本文将详细说明如何让使用Mudule类来自定义一个模型。

      pytorch里面一切自定义操作基本上都是继承nn.Module类来实现的。

      我们在定义自已的网络的时候,需要继承nn.Module类,并重新实现构造函数__init__构造函数和forward这两个方法。但有一些注意技巧:

    • 一般把网络中具有可学习参数的层(如全连接层、卷积层等)放在构造函数__init__()中,当然我也可以吧不具有参数的层也放在里面;
    • 一般把不具有可学习参数的层(如ReLU、dropout、BatchNormanation层)可放在构造函数中,也可不放在构造函数中,如果不放在构造函数__init__里面,则在forward方法里面可以使用nn.functional来代替
    • forward方法是必须要重写的,它是实现模型的功能,实现各个层之间的连接关系的核心。

      所有放在构造函数__init__里面的层的都是这个模型的“固有属性”。

      官方例子

    import torch.nn as nn
    import torch.nn.functional as F
    
    class Model(nn.Module):
        def __init__(self):
            # 固定内容
            super(Model, self).__init__()
    
            # 定义相关的函数
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)
    
        def forward(self, x):
            # 构建模型结构,可以使用F函数内容,其他调用__init__里面的函数
            x = F.relu(self.conv1(x))
    
            # 返回最终的结果
            return F.relu(self.conv2(x))
    

    ☀☀☀<< 举例 >>☀☀☀

      代码一:

    import torch
    
    N, D_in, H, D_out = 64, 1000, 100, 10
     
    torch.manual_seed(1)
    x = torch.randn(N, D_in)
    y = torch.randn(N, D_out)
     
    #-----changed part-----#
    model = torch.nn.Sequential(
        torch.nn.Linear(D_in, H),
        torch.nn.ReLU(),
        torch.nn.Linear(H, D_out),
    )
    #-----changed part-----#
    
    loss_fn = torch.nn.MSELoss(reduction='sum')
    learning_rate = 1e-4
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    for t in range(500):
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        if t % 100 == 99:
            print(t, loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    

      代码二:

    import torch
    
    N, D_in, H, D_out = 64, 1000, 100, 10
    
    torch.manual_seed(1)
    x = torch.randn(N, D_in)
    y = torch.randn(N, D_out)
    
    #-----changed part-----#
    class Alex_nn(nn.Module):
        def __init__(self):
            super(Alex_nn, self).__init__()
            self.h1 = torch.nn.Linear(D_in, H)
            self.h1_relu = torch.nn.ReLU()
            self.output = torch.nn.Linear(H, D_out)
            
        def forward(self, x):
            h1 = self.h1(x)
            h1_relu = self.h1_relu(h1)
            output = self.output(h1_relu)
            return output
            
    model = Alex_nn()
    #-----changed part-----#
    
    loss_fn = torch.nn.MSELoss(reduction='sum')
    learning_rate = 1e-4
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    for t in range(500):
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        if t % 100 == 99:
            print(t, loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    

      代码三:

    import torch
    
    N, D_in, H, D_out = 64, 1000, 100, 10
    
    torch.manual_seed(1)
    x = torch.randn(N, D_in)
    y = torch.randn(N, D_out)
    
    #-----changed part-----#
    class Alex_nn(nn.Module):
        def __init__(self, D_in_, H_, D_out_):
            super(Alex_nn, self).__init__()
            self.D_in = D_in_
            self.H = H_
            self.D_out = D_out_
            
            self.h1 = torch.nn.Linear(self.D_in, self.H)
            self.h1_relu = torch.nn.ReLU()
            self.output = torch.nn.Linear(self.H, self.D_out)
            
        def forward(self, x):
            h1 = self.h1(x)
            h1_relu = self.h1_relu(h1)
            output = self.output(h1_relu)
            return output
            
    model = Alex_nn(D_in, H, D_out)
    #-----changed part-----#
    
    loss_fn = torch.nn.MSELoss(reduction='sum')
    learning_rate = 1e-4
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    for t in range(500):
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        if t % 100 == 99:
            print(t, loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    

      

  • 相关阅读:
    Windows XP下 Android开发环境 搭建
    Android程序的入口点
    在eclipse里 新建android项目时 提示找不到proguard.cfg
    64位WIN7系统 下 搭建Android开发环境
    在eclipse里 新建android项目时 提示找不到proguard.cfg
    This Android SDK requires Android Developer Toolkit version 20.0.0 or above
    This Android SDK requires Android Developer Toolkit version 20.0.0 or above
    Android requires compiler compliance level 5.0 or 6.0. Found '1.4' instead
    Windows XP下 Android开发环境 搭建
    Android程序的入口点
  • 原文地址:https://www.cnblogs.com/alex-bn-lee/p/14092666.html
Copyright © 2011-2022 走看看