zoukankan      html  css  js  c++  java
  • PyTorch入门:使用PyTorch搭建神经网络LeNet5

    前言

    在本文中,我们基于PyTorch构建一个简单的神经网络LeNet5。

    在阅读本文之前,建议您了解一些卷积神经网络的前置知识,比如卷积、Max Pooling和全连接层等等,可以看我写的相关文章:李宏毅机器学习课程笔记-7.1CNN入门详解

    通过阅读本文,您可以学习到如何使用PyTorch构建神经网络LeNet5。

    模型说明

    在本例中,我们使用如下图所示的神经网络模型:LeNet5。

    img

    该模型有1个输入层、2个卷积层、2次Max Pooling、2个全连接层和1个输出层。

    • 输入层INPUT

      1个channel,图片size是32×32。

    • 卷积层C1

      6个channel,特征图的size是28×28,即每个卷积核的size为(5,5),stride为1。

    • 下采样操作S2

      6个channel,特征图的size是14×14,即Max Pooling窗口size为(2,2)。

    • 卷积层C3

      16个channel,特征图的size是10×10,即每个卷积核的size为(5,5),stride为1。

    • 下采样操作S4

      16个channel,特征图的size是5×5,即Max Pooling窗口size为(2,2)。

    • 全连接层F5

      120个神经元。

    • 全连接层F6

      84个神经元。

    • 输出层OUTPUT

      10个神经元。

    另外,除了输入层和输出层,剩下的卷积层、最大池化操作和全连接层后面都要加上Relu激活函数,下采样操作S4之后需要进行Flatten以和全连接层F5衔接起来。

    代码实现

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    
    class LeNet5(nn.Module):
        def __init__(self):
            super(LeNet5, self).__init__()
            # 卷积层
            self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
            self.conv3 = nn.Conv2d(6, 16, 5)
            # 全连接层
            self.fc5 = nn.Linear(in_features=16*5*5, out_features=120)
            self.fc6 = nn.Linear(120, 84)
            self.OUTPUT = nn.Linear(84, 10)
    
        def forward(self, x):
            x = F.relu(self.conv1(x))
            x = F.max_pool2d(x, (2, 2)) # Max pooling over a (2, 2) window
            x = F.relu(self.conv3(x))
            x = F.max_pool2d(x, 2) # If the size is a square you can only specify a single number
            x = x.view(-1, 16*5*5) # Flatten
            x = F.relu(self.fc5(x))
            x = F.relu(self.fc6(x))
            x = self.OUTPUT(x)
            return x
    
    net = LeNet5()
    output = net(torch.rand(1, 1, 32, 32))
    # print(output)
    

    参考链接

    https://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html

    其实本文内容主要是PyTorch的官方教程。

    PyTorch官方教程中代码实现与图片所示的LeNet5不符(PyTorch官方教程代码中是3×3的卷积核,而图片中LeNet5是5×5的卷积核),本文中我是按照图片所示模型结构实现的。

    其实PyTorch开发者和其他开发者也注意到了这一问题,详见:

    https://github.com/pytorch/tutorials/pull/515

    https://github.com/pytorch/tutorials/commit/630802450c13c78f02f744af1c47d1033b6fe206

    https://github.com/pytorch/tutorials/pull/1257


    Github(github.com):@chouxianyu

    Github Pages(github.io):@臭咸鱼

    知乎(zhihu.com):@臭咸鱼

    博客园(cnblogs.com):@臭咸鱼

    B站(bilibili.com):@绝版臭咸鱼

    微信公众号:@臭咸鱼

    转载请注明出处,欢迎讨论和交流!


  • 相关阅读:
    python 10大算法之一 LinearRegression 笔记
    Android+openCV 动态人脸检测
    ubuntu+github配置使用
    Android+openCV人脸检测2(静态图片)
    Android CameraManager 类
    Android人脸检测1(静态图片)
    Android读写配置2
    Git分支(branch)
    mvn
    git 停止跟踪某一个文件
  • 原文地址:https://www.cnblogs.com/chouxianyu/p/14613460.html
Copyright © 2011-2022 走看看