zoukankan      html  css  js  c++  java
  • Pytorch实现卷积神经网络CNN

    Pytorch是torch的Python版本,对TensorFlow造成很大的冲击,TensorFlow无疑是最流行的,但是Pytorch号称在诸多性能上要优于TensorFlow,比如在RNN的训练上,所以Pytorch也吸引了很多人的关注。之前有一篇关于TensorFlow实现的CNN可以用来做对比。

    下面我们就开始用Pytorch实现CNN。

    step 0 导入需要的包

    1 import torch 
    2 import torch.nn as nn
    3 from torch.autograd import Variable
    4 import torch.utils.data as data
    5 import matplotlib.pyplot as plt

    step 1  数据预处理

    这里需要将training data转化成torch能够使用的DataLoader,这样可以方便使用batch进行训练。

     1 import torchvision  #数据库模块
     2 
     3 torch.manual_seed(1) #reproducible
     4 
     5 #Hyper Parameters
     6 EPOCH = 1
     7 BATCH_SIZE = 50
     8 LR = 0.001
     9 
    10 train_data = torchvision.datasets.MNIST(
    11     root='/mnist/', #保存位置
    12     train=True, #training set
    13     transform=torchvision.transforms.ToTensor(), #converts a PIL.Image or numpy.ndarray 
    14                                         #to torch.FloatTensor(C*H*W) in range(0.0,1.0)
    15     download=True
    16 )
    17 
    18 test_data = torchvision.datasets.MNIST(root='/MNIST/')
    19 #如果是普通的Tensor数据,想使用torch_dataset = data.TensorDataset(data_tensor=x, target_tensor=y)
    20 #将Tensor转换成torch能识别的dataset
    21 #批训练, 50 samples, 1 channel, 28*28, (50, 1, 28 ,28)
    22 train_loader = data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
    23 
    24 test_x = Variable(torch.unsqueeze(test_data.test_data, dim=1), volatile=True).type(torch.FloatTensor)[:2000]/255.
    25 test_y = test_data.test_lables[:2000]

    step 2 定义网络结构

    需要指出的几个地方:1)class CNN需要继承Module ; 2)需要调用父类的构造方法:super(CNN, self).__init__()  ;3)在Pytorch中激活函数Relu也算是一层layer; 4)需要实现forward()方法,用于网络的前向传播,而反向传播只需要调用Variable.backward()即可。

     1 class CNN(nn.Module):
     2     def __init__(self):
     3         super(CNN, self).__init__()
     4         self.conv1 = nn.Sequential( #input shape (1,28,28)
     5             nn.Conv2d(in_channels=1, #input height 
     6                       out_channels=16, #n_filter
     7                      kernel_size=5, #filter size
     8                      stride=1, #filter step
     9                      padding=2 #con2d出来的图片大小不变
    10                      ), #output shape (16,28,28)
    11             nn.ReLU(),
    12             nn.MaxPool2d(kernel_size=2) #2x2采样,output shape (16,14,14)
    13               
    14         )
    15         self.conv2 = nn.Sequential(nn.Conv2d(16, 32, 5, 1, 2), #output shape (32,7,7)
    16                                   nn.ReLU(),
    17                                   nn.MaxPool2d(2))
    18         self.out = nn.Linear(32*7*7,10)
    19         
    20     def forward(self, x):
    21         x = self.conv1(x)
    22         x = self.conv2(x)
    23         x = x.view(x.size(0), -1) #flat (batch_size, 32*7*7)
    24         output = self.out(x)
    25         return output

    step 3 查看网络结构

    使用print(cnn)可以看到网络的结构详细信息,ReLU()真的是一层layer。

    1 cnn = CNN()
    2 print(cnn)

    step 4 训练

    指定optimizer,loss function,需要特别指出的是记得每次反向传播前都要清空上一次的梯度,optimizer.zero_grad()。

     1 #optimizer
     2 optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)
     3 
     4 #loss_fun
     5 loss_func = nn.CrossEntropyLoss()
     6 
     7 #training loop
     8 for epoch in range(EPOCH):
     9     for i, (x, y) in enumerate(train_loader):
    10         batch_x = Variable(x)
    11         batch_y = Variable(y)
    12         #输入训练数据
    13         output = cnn(batch_x)
    14         #计算误差
    15         loss = loss_func(output, batch_y)
    16         #清空上一次梯度
    17         optimizer.zero_grad()
    18         #误差反向传递
    19         loss.backward()
    20         #优化器参数更新
    21         optimizer.step()

    step 5 预测结果

    1 test_output =cnn(test_x[:10])
    2 pred_y = torch.max(test_output,1)[1].data.numpy().squeeze()
    3 print(pred_y, 'prediction number')
    4 print(test_y[:10])

    reference:

    莫凡python pytorch 教程

  • 相关阅读:
    poj 2187 Beauty Contest(旋转卡壳)
    poj 2540 Hotter Colder(极角计算半平面交)
    poj 1279 Art Gallery(利用极角计算半平面交)
    poj 3384 Feng Shui(半平面交的联机算法)
    poj 1151 Atlantis(矩形面积并)
    zoj 1659 Mobile Phone Coverage(矩形面积并)
    uva 10213 How Many Pieces of Land (欧拉公式计算多面体)
    uva 190 Circle Through Three Points(三点求外心)
    zoj 1280 Intersecting Lines(两直线交点)
    poj 1041 John's trip(欧拉回路)
  • 原文地址:https://www.cnblogs.com/yangmang/p/7530748.html
Copyright © 2011-2022 走看看