zoukankan      html  css  js  c++  java
  • 龙良曲pytorch学习笔记_mnist

      1 import torch
      2 from torch import nn
      3 from torch.nn import functional as F
      4 from torch import optim
      5 
      6 import torchvision
      7 from matplotlib import pyplot as plt
      8 
      9 # 小工具
     10 
     11 def plot_curve(data):
     12     fig = plt.figure()
     13     plt.plot(range(len(data)),data,color='blue')
     14     plt.legend(['value'],loc='upper right')
     15     plt.xlabel('step')
     16     plt.tlabel('value')
     17     plt.show()
     18 
     19 def plot_image(img,label,name):
     20     fig = plt.figure()
     21     for i in range(6):
     22         plt.subplot(2,3,i+1)
     23         plt,tight_layout()
     24         plt.imshow(img[i][0]*0.3081+0.1307,cmap='gray',interpolation='none')
     25         plt.title("{}:{}".format(name,label[i].item()))
     26         plt.xticks([])
     27         plt.xticks([])
     28         
     29     plt.show()
     30         
     31 def one_hot(label,depth = 10):
     32     out = torch.zeros(label.size(0),depth)
     33     idx = torch.LongTensor(label).view(-1,1)
     34     out.scatter_(dim=1,index=idx,value=1)
     35     return out
     36     
     37 # 一次加载多少图片
     38 batch_size = 512
     39 # step1. load dataset 数据加载
     40 train_loader = torch.utils.data.DataLoader(
     41     torchvision.datasets.MINST('mnist_data',train=True,download=True,
     42                               transform=torchvision.transforms.Compose([
     43                                   torchvision.transfroms.ToTensor(),
     44                                   
     45                                   torchvision.transfroms.Normalize(
     46                                       (0.1307,),(0.3081,))
     47                               ])),
     48     batch_size=batch_size,shuffle=True)
     49 test_loader = torch.utils.data.DataLoader(
     50     torchvision.datasets.MINST('mnist_data/',train=False,download=True,
     51                               transform=torchvision.transforms.Compose([
     52                                   torchvision.transfroms.ToTensor(),
     53                                   torchvision.transfroms.Normalize(
     54                                       (0.1307,),(0.3081,))
     55                               ])),
     56     batch_size=batch_size,shuffle=False)
     57     
     58 # 网络创建
     59 class Net(nn.Module):
     60     
     61     def __init__(self):
     62         super(Net,self).__init__()
     63         
     64     #xw+b
     65     self.fc1 = nn.Linear(28*28,256)
     66     self.fc2 = nn.Linear(256,64)
     67     self.fc3 = nn.Linear(64,10)
     68     
     69     def forward(self,x):
     70         # x:[batch_size,1,28,28]
     71         # h1 = relu(xw1+b1)
     72         x = F.relu(self.fc1(x))
     73         # h1 = relu(h1w2+b2)
     74         x = F.relu(self.fc2(x))
     75         # h3 = h2w3+b3
     76         x = self.fc3(x)
     77         
     78         return x
     79         
     80 net = Net()
     81 # [w1,b1,w2,b1,w3,b3]
     82 optimizer = optim.SGD(net.parameters(),lr=0.01,momentum=0.9)
     83 
     84 train_loss = []
     85 
     86 # 训练
     87 for epoch in range(3):
     88 
     89     for batch_idx,(x,y) in enumerate(train_loader):
     90     
     91         # x: [b,1,28,28], y:[512]
     92         # [b,1,28,28]-->[b,feature]
     93         x = x.view(x.size(0),28*28)
     94         # --> [b,10]
     95         out = net(x)
     96         # --> [b,10]
     97         y_onehot = one_hot(y)
     98         # loss = mse(out,y_onehot)
     99         loss = F.mse_loss(out,y_onehot)
    100         # 清零梯度
    101         optimizer.zero_grad()
    102         # 计算梯度
    103         loss.backward()
    104         #w' = w - lr*grad 更新梯度
    105         optimizer.step()
    106         
    107         train_loss.append(loss.item())
    108         
    109         if batch_idx % 10 == 0:
    110             print(epoch,batch_idx,loss.item())
    111             
    112 plot_curve(train_loss)
    113             
    114 # 得到一个比较好的    [w1,b1,w2,b1,w3,b3]    
    115 
    116 
    117 # 验证准确率
    118 total_correct = 0
    119 for x,y in test_loader"
    120     x = x.view(x.size(0),28*28)
    121     out = net(x)
    122     # out: [b,10] --> pred: [b]
    123     pred = out.argmax(dim = 1)
    124     correct = pred.eq(y).sum().float().item()
    125     total_correct += correct
    126 
    127 total_num = len(test_loader.dataset)
    128 acc = total_correct / total_num
    129 print('test acc:',acc)
    130 
    131 # 直观显示验证
    132 x,y = next(iter(test_loader))
    133 out = net(x.view(x.size(0),28*28))
    134 pred = out.argmax(dim = 1)
    135 plot_image(x,pred,'test')
    136         
    137         
    138         
    139         
    140         
    141         
  • 相关阅读:
    DateTime的精度小问题
    使用For XML PATH 会影响Cross Apply 返回
    一个update的小故事
    行大小计算测试
    Sql Server 2008R2 遇到了BCP导入各种中文乱码的问题
    php-fpm 启动不了 libiconv.so.2找不到
    Git使用教程
    支付宝接口使用文档说明 支付宝异步通知
    Linux(CentOs6.4)安装Git
    NGINX防御CC攻击教程
  • 原文地址:https://www.cnblogs.com/fxw-learning/p/12292461.html
Copyright © 2011-2022 走看看