zoukankan      html  css  js  c++  java
  • pytorch实践:MNIST数字识别(转)

    手写数字识别是深度学习界的“HELLO WPRLD”。网上代码很多,找一份自己读懂,对整个学习网络理解会有帮助。不必多说,直接贴代码吧(代码是网上找的,时间稍久,来处不可考,侵删)

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim as optim
    from torchvision import datasets, transforms
    from torch.autograd import Variable
    
    # Training settings
    batch_size = 64
    
    # MNIST Dataset
    # MNIST数据集已经集成在pytorch datasets中,可以直接调用

    train_dataset = datasets.MNIST(root='./data/', train=True, transform=transforms.ToTensor(), download=True) test_dataset = datasets.MNIST(root='./data/', train=False, transform=transforms.ToTensor()) # Data Loader (Input Pipeline) train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True) test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False) class Net(nn.Module): def __init__(self): super(Net, self).__init__() # 输入1通道,输出10通道,kernel 5*5 self.conv1 = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5) self.conv2 = nn.Conv2d(10, 20, 5) self.conv3 = nn.Conv2d(20, 40, 3) self.mp = nn.MaxPool2d(2) # fully connect self.fc = nn.Linear(40, 10)#(in_features, out_features) def forward(self, x): # in_size = 64 in_size = x.size(0) # one batch 此时的x是包含batchsize维度为4的tensor,即(batchsize,channels,x,y),x.size(0)指batchsize的值 把batchsize的值作为网络的in_size # x: 64*1*28*28 x = F.relu(self.mp(self.conv1(x))) # x: 64*10*12*12 feature map =[(28-4)/2]^2=12*12 x = F.relu(self.mp(self.conv2(x))) # x: 64*20*4*4 x = F.relu(self.mp(self.conv3(x))) x = x.view(in_size, -1) # flatten the tensor 相当于resharp # print(x.size()) # x: 64*320 x = self.fc(x) # x:64*10 # print(x.size()) return F.log_softmax(x) #64*10 model = Net() optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5) def train(epoch): for batch_idx, (data, target) in enumerate(train_loader):#batch_idx是enumerate()函数自带的索引,从0开始 # data.size():[64, 1, 28, 28] # target.size():[64] output = model(data) #output:64*10 loss = F.nll_loss(output, target) if batch_idx % 200 == 0: print('Train Epoch: {} [{}/{} ({:.0f}%)] Loss: {:.6f}'.format( epoch, batch_idx * len(data), len(train_loader.dataset), 100. * batch_idx / len(train_loader), loss.data[0])) optimizer.zero_grad() # 所有参数的梯度清零 loss.backward() #即反向传播求梯度 optimizer.step() #调用optimizer进行梯度下降更新参数 def test(): test_loss = 0 correct = 0 for data, target in test_loader: data, target = Variable(data, volatile=True), Variable(target) output = model(data) # sum up batch loss test_loss += F.nll_loss(output, target, size_average=False).data[0] # get the index of the max log-probability pred = output.data.max(1, keepdim=True)[1] print(pred) correct += pred.eq(target.data.view_as(pred)).cpu().sum() test_loss /= len(test_loader.dataset) print(' Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%) '.format( test_loss, correct, len(test_loader.dataset), 100. * correct / len(test_loader.dataset))) for epoch in range(1, 10): train(epoch) test()
  • 相关阅读:
    Spring源码情操陶冶-AbstractApplicationContext#obtainFreshBeanFactory
    Spring源码情操陶冶-AbstractApplicationContext#prepareRefresh
    Spring源码情操陶冶-AbstractApplicationContext
    Spring源码情操陶冶-ContextLoader
    Spring源码情操陶冶-ContextLoaderListener
    Spring mybatis源码篇章-MapperScannerConfigurer
    Spring mybatis源码篇章-动态SQL节点源码深入
    Spring mybatis源码篇章-动态SQL基础语法以及原理
    Spring mybatis源码篇章-Mybatis的XML文件加载
    Spring mybatis源码篇章-Mybatis主文件加载
  • 原文地址:https://www.cnblogs.com/jiangnanyanyuchen/p/9782223.html
Copyright © 2011-2022 走看看