zoukankan      html  css  js  c++  java
  • Pytorch实现卷积神经网络CNN

    Pytorch是torch的Python版本,对TensorFlow造成很大的冲击,TensorFlow无疑是最流行的,但是Pytorch号称在诸多性能上要优于TensorFlow,比如在RNN的训练上,所以Pytorch也吸引了很多人的关注。之前有一篇关于TensorFlow实现的CNN可以用来做对比。

    下面我们就开始用Pytorch实现CNN。

    step 0 导入需要的包

    1 import torch 
    2 import torch.nn as nn
    3 from torch.autograd import Variable
    4 import torch.utils.data as data
    5 import matplotlib.pyplot as plt

    step 1  数据预处理

    这里需要将training data转化成torch能够使用的DataLoader,这样可以方便使用batch进行训练。

     1 import torchvision  #数据库模块
     2 
     3 torch.manual_seed(1) #reproducible
     4 
     5 #Hyper Parameters
     6 EPOCH = 1
     7 BATCH_SIZE = 50
     8 LR = 0.001
     9 
    10 train_data = torchvision.datasets.MNIST(
    11     root='/mnist/', #保存位置
    12     train=True, #training set
    13     transform=torchvision.transforms.ToTensor(), #converts a PIL.Image or numpy.ndarray 
    14                                         #to torch.FloatTensor(C*H*W) in range(0.0,1.0)
    15     download=True
    16 )
    17 
    18 test_data = torchvision.datasets.MNIST(root='/MNIST/')
    19 #如果是普通的Tensor数据,想使用torch_dataset = data.TensorDataset(data_tensor=x, target_tensor=y)
    20 #将Tensor转换成torch能识别的dataset
    21 #批训练, 50 samples, 1 channel, 28*28, (50, 1, 28 ,28)
    22 train_loader = data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
    23 
    24 test_x = Variable(torch.unsqueeze(test_data.test_data, dim=1), volatile=True).type(torch.FloatTensor)[:2000]/255.
    25 test_y = test_data.test_lables[:2000]

    step 2 定义网络结构

    需要指出的几个地方:1)class CNN需要继承Module ; 2)需要调用父类的构造方法:super(CNN, self).__init__()  ;3)在Pytorch中激活函数Relu也算是一层layer; 4)需要实现forward()方法,用于网络的前向传播,而反向传播只需要调用Variable.backward()即可。

     1 class CNN(nn.Module):
     2     def __init__(self):
     3         super(CNN, self).__init__()
     4         self.conv1 = nn.Sequential( #input shape (1,28,28)
     5             nn.Conv2d(in_channels=1, #input height 
     6                       out_channels=16, #n_filter
     7                      kernel_size=5, #filter size
     8                      stride=1, #filter step
     9                      padding=2 #con2d出来的图片大小不变
    10                      ), #output shape (16,28,28)
    11             nn.ReLU(),
    12             nn.MaxPool2d(kernel_size=2) #2x2采样,output shape (16,14,14)
    13               
    14         )
    15         self.conv2 = nn.Sequential(nn.Conv2d(16, 32, 5, 1, 2), #output shape (32,7,7)
    16                                   nn.ReLU(),
    17                                   nn.MaxPool2d(2))
    18         self.out = nn.Linear(32*7*7,10)
    19         
    20     def forward(self, x):
    21         x = self.conv1(x)
    22         x = self.conv2(x)
    23         x = x.view(x.size(0), -1) #flat (batch_size, 32*7*7)
    24         output = self.out(x)
    25         return output

    step 3 查看网络结构

    使用print(cnn)可以看到网络的结构详细信息,ReLU()真的是一层layer。

    1 cnn = CNN()
    2 print(cnn)

    step 4 训练

    指定optimizer,loss function,需要特别指出的是记得每次反向传播前都要清空上一次的梯度,optimizer.zero_grad()。

     1 #optimizer
     2 optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)
     3 
     4 #loss_fun
     5 loss_func = nn.CrossEntropyLoss()
     6 
     7 #training loop
     8 for epoch in range(EPOCH):
     9     for i, (x, y) in enumerate(train_loader):
    10         batch_x = Variable(x)
    11         batch_y = Variable(y)
    12         #输入训练数据
    13         output = cnn(batch_x)
    14         #计算误差
    15         loss = loss_func(output, batch_y)
    16         #清空上一次梯度
    17         optimizer.zero_grad()
    18         #误差反向传递
    19         loss.backward()
    20         #优化器参数更新
    21         optimizer.step()

    step 5 预测结果

    1 test_output =cnn(test_x[:10])
    2 pred_y = torch.max(test_output,1)[1].data.numpy().squeeze()
    3 print(pred_y, 'prediction number')
    4 print(test_y[:10])

    reference:

    莫凡python pytorch 教程

  • 相关阅读:
    各国语言缩写列表,各国语言缩写-各国语言简称,世界各国域名缩写
    How to see log files in MySQL?
    git 设置和取消代理
    使用本地下载和管理的免费 Windows 10 虚拟机测试 IE11 和旧版 Microsoft Edge
    在Microsoft SQL SERVER Management Studio下如何完整输出NVARCHAR(MAX)字段或变量的内容
    windows 10 x64系统下在vmware workstation pro 15安装macOS 10.15 Catelina, 并设置分辨率为3840x2160
    在Windows 10系统下将Git项目签出到磁盘分区根目录的方法
    群晖NAS(Synology NAS)环境下安装GitLab, 并在Windows 10环境下使用Git
    使用V-2ray和V-2rayN搭建本地代理服务器供局域网用户连接
    windows 10 专业版安装VMware虚拟机碰到的坑
  • 原文地址:https://www.cnblogs.com/yangmang/p/7530748.html
Copyright © 2011-2022 走看看