zoukankan      html  css  js  c++  java
  • 4. 神经网络的简单搭建、卷积操作与卷积层

      在前3篇博客介绍完pytorch的基础知识之后,我这里我们接着介绍简单网络的搭建,详述卷积操作,最后根据卷积操作搭建 神经网络的卷积层。

    1. nn.Module的简单使用

         官方帮助文档

      首先,我们还是要从帮助文档看起,进入 pytorch 官网,查看 Pytorch 的官方帮助文档


      然后进入 torch.nn 部分(nn 是神经网络 neural network 的简称),查看 container 容器(也可以称之为骨架)

    我们可以看到 torch.nn 下面有很多东西,像是卷积层、池化层、非线性激活、正则化层等等,感兴趣可以提前看一下,后续博客我会有介绍。



      我们主要查看两个部分,torch.nn.Module 的介绍,以及他的 使用:



         代码示例运行

      看完,帮助文档之后,我们进行试验一下(需要用到之前的 tensorboarddatasetstransformDataLoader,如果不会或者是已经忘记需要参考一下我之前写的博客)。
    下面我们的定义了一个 MyModel 类别,继承而来 torch.nn.Module,forward完成的是一个简单的加法操作。

    import cv2
    from PIL import Image
    import torch
    import torchvision
    from torch.utils.data import DataLoader
    from torch.utils.data import Dataset
    from torch.utils.tensorboard import SummaryWriter
    from torchvision import transforms
    
    
    class MyModel(torch.nn.Module):
        def __init__(self, delta):
            super(MyModel, self).__init__()
            self.delta = delta
    
        def forward(self, x):
            return x + self.delta
    
    
    if __name__ == "__main__":
        my_model = MyModel(torch.tensor(10))
        x = torch.tensor(5)
        print(my_model(x))
    

    2. 卷积操作(convolution)

      学过数字图像处理或者是相关课程的小伙伴们对卷积操作一定不陌生,倘若不了解的话,可以参考知乎上卷积的提问,个人认为写的非常好!

         torch.nn.functional.conv2d()

      首先,我们查阅官方文档



    根据官方文档,我们写一个例子来实践一下:

    import torch
    import torchvision
    from torch.utils.tensorboard import SummaryWriter
    from torch.utils.data import DataLoader
    import torch.nn.functional as F
    
    
    input = torch.tensor([
        [1, 2, 0, 3, 1],
        [0, 1, 2, 3, 1],
        [1, 2, 1, 0, 0],
        [5, 2, 3, 1, 1],
        [2, 1, 0, 1, 1]
    ])
    kernel = torch.tensor([
        [1, 2, 1],
        [0, 1, 0],
        [2, 1, 0]
    ])
    
    # 为了满足 torch.nn.functional.con2d 输入和卷积核的 size 类型,我们需要对其进行 reshape
    my_list = list(input.shape)
    my_list.insert(0, 1)
    my_list.insert(0, -1)
    input = torch.reshape(input, my_list)   # 直接这样插入。。。
    kernel = torch.reshape(kernel, [1, 1, -1, 3])
    
    result1 = F.conv2d(input=input, weight=kernel, stride=1, padding=0)
    print(result1)
    
    
    result2 = F.conv2d(input=input, weight=kernel, stride=1, padding=1)
    print(result2)
    
    

    3. 神经网络的卷积层

    前两部分,我们了解了 torch.nn.Module,使用该抽象类派生出 Model,并且学习了 torch.nn.functional.conv2d() 卷积操作,下面我们将要学习 torch.nn.Conv2d(),并用此来写一个神经网络的卷积层

         torch.nn.Conv2d()

      下面,我们查看 Torch.nn.Conv2d的官方文档,并对参数进行一些讲解。



         简单的卷积神经网络

    步骤如下所示:

    1. 首先,我们仍旧是使用 datasets 和 Dataloader 进行加载数据集
    2. 使用 torch.nn.Module 派生出一个简单的神经网络
    3. 将网络的运行结果写入到 tensorboard 可视化工具中
    4. 使用 tensorboard 可视化查看结果

    代码:

    import torch
    import torchvision
    from PIL import Image
    import cv2
    from torch.utils.tensorboard import SummaryWriter
    from torch.utils.data import DataLoader
    
    
    class MyModel(torch.nn.Module):
        r""" 
            a class used as neural network modek
        """
        def __init__(self):
            super(MyModel, self).__init__()
            self.conv1 = torch.nn.Conv2d(in_channels=3, out_channels=1, kernel_size=(4, 4), stride=1, padding=1)
            self.conv2 = torch.nn.Conv2d(in_channels=1, out_channels=3, kernel_size=(3, 3), stride=2, padding=1)
    
        def forward(self, x):
            x = self.conv1(x)
            x = self.conv2(x)
            return x
    
    
    data_path = "../../data_cifar10"
    dataset_test = torchvision.datasets.CIFAR10(data_path, train=False, transform=torchvision.transforms.ToTensor(),                                            download=True)
    data_loader = DataLoader(dataset_test, batch_size=64, shuffle=True, drop_last=False)
    
    log_path = "../../logs"
    writer = SummaryWriter(log_dir=log_path)
    step = 0
    my_model = MyModel()
    step_list = []
    for imgs, targets in data_loader:
        step_list.append(step)
        writer.add_images("original", imgs, step)
        imgs = my_model(imgs)
        writer.add_images("conventional", imgs, step)
        step += 1
    
    writer.close()
    print(step_list)
    

    author:luckylight(xyg)
    date: 2021/11/11
  • 相关阅读:
    junit报错----java.lang.Exception: No tests found matching
    javaweb基础----省市县三级联动(jquery+ajax版)
    javaweb基础----使用jquery的ajax
    javaweb基础----对用户登录密码加密
    javaweb基础----c3p0数据库连接池的相关配置
    javaweb基础----请求转发与重定向的区别
    javaweb基础----eclipse中书写JSP代码自动提示(JavaScipt)
    char、布尔、赋值、算数运算符
    变量常量整型浮点型
    标识符的规范
  • 原文地址:https://www.cnblogs.com/lucky-light/p/15540237.html
Copyright © 2011-2022 走看看