  • 1. Classifying MNIST with an RNN

    This is my first time using an LSTM, so I'm starting with something simple.

    Notes:

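    • batch_first=True means the input has shape (batch_size, time_step, input_size); with batch_first=False (the default) the input has shape (time_step, batch_size, input_size).
    • r_out[:,-1,:] takes the output of the last time step, i.e. the state vector the LSTM returns after it has scanned the whole image row by row. Its shape is (64, 64): the first 64 is batch_size and the second 64 is the number of hidden units (hidden_size).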
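The two layouts differ only in the order of the first two axes, and each image row becomes one time step. The sketch below uses plain nested lists so it runs without PyTorch; the helper names (`shape`, `drop_channel`, `batch_first_to_time_major`) and the synthetic pixel values are hypothetical. In PyTorch the same conversions are `x.view(-1, 28, 28)`, as in the script, and `x.transpose(0, 1)`.

```python
def shape(nested):
    """Shape of a regular nested list, e.g. (3, 28, 28)."""
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0]
    return tuple(dims)


def drop_channel(images):
    """(batch, 1, height, width) -> (batch, height, width): rows become time steps."""
    return [image[0] for image in images]


def batch_first_to_time_major(batch):
    """(batch, time_step, input_size) -> (time_step, batch, input_size)."""
    return [[sample[t] for sample in batch] for t in range(len(batch[0]))]


# Three synthetic single-channel 28 x 28 images, laid out like a DataLoader batch.
images = [[[[float(b * 1000 + r * 28 + c) for c in range(28)] for r in range(28)]]
          for b in range(3)]
batch_first = drop_channel(images)                   # layout for batch_first=True
time_major = batch_first_to_time_major(batch_first)  # layout for batch_first=False
print(shape(images), shape(batch_first), shape(time_major))
# (3, 1, 28, 28) (3, 28, 28) (28, 3, 28)
```

The data itself is unchanged; only the indexing order differs, so `time_major[t][b][i]` is the same pixel as `batch_first[b][t][i]`.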
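To see why the last time step summarizes the whole image, here is a minimal pure-Python sketch of an LSTM reading a 28 × 28 image one row per step. It is an illustration under assumptions: the class name `TinyLSTM`, the random untrained weights, one bias per gate (PyTorch keeps two, which only sum), a hidden size of 8 instead of 64, and the missing batch dimension are all hypothetical simplifications. It does follow the standard LSTM gate equations (input i, forget f, candidate g, output o). Each step's output is the hidden state after that step, so the last output equals the final hidden state h_n; for a single-layer, unidirectional `torch.nn.LSTM` with batch_first=True and no packed sequences, `r_out[:, -1, :]` likewise equals `h_n[-1]`.

```python
import math
import random


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def matvec(w, v):
    """Multiply a matrix (list of rows) by a vector (list)."""
    return [sum(wij * vj for wij, vj in zip(row, v)) for row in w]


class TinyLSTM:
    """Hypothetical single-layer, unidirectional LSTM for ONE sequence (no batch)."""

    def __init__(self, input_size, hidden_size, seed=0):
        rng = random.Random(seed)

        def mat(rows, cols):
            return [[rng.uniform(-0.1, 0.1) for _ in range(cols)] for _ in range(rows)]

        self.hidden_size = hidden_size
        # One input-to-hidden and one hidden-to-hidden matrix per gate.
        self.w_x = {gate: mat(hidden_size, input_size) for gate in "ifgo"}
        self.w_h = {gate: mat(hidden_size, hidden_size) for gate in "ifgo"}
        self.bias = {gate: [0.0] * hidden_size for gate in "ifgo"}

    def step(self, x, h, c):
        """Consume one input row x; return the updated hidden and cell states."""
        pre = {}
        for gate in "ifgo":
            wx = matvec(self.w_x[gate], x)
            wh = matvec(self.w_h[gate], h)
            pre[gate] = [a + b + bk for a, b, bk in zip(wx, wh, self.bias[gate])]
        i = [sigmoid(z) for z in pre["i"]]
        f = [sigmoid(z) for z in pre["f"]]
        g = [math.tanh(z) for z in pre["g"]]
        o = [sigmoid(z) for z in pre["o"]]
        c_new = [fk * ck + ik * gk for fk, ck, ik, gk in zip(f, c, i, g)]
        h_new = [ok * math.tanh(ck) for ok, ck in zip(o, c_new)]
        return h_new, c_new

    def forward(self, seq):
        """Return (output of every step, (h_n, c_n)), starting from zero states."""
        h = [0.0] * self.hidden_size
        c = [0.0] * self.hidden_size
        outputs = []
        for x in seq:
            h, c = self.step(x, h, c)
            outputs.append(h)
        return outputs, (h, c)


# A synthetic 28 x 28 "image": each of the 28 rows is one time step of 28 pixels.
image = [[((r * 28 + col) % 256) / 255.0 for col in range(28)] for r in range(28)]
lstm = TinyLSTM(input_size=28, hidden_size=8)
outputs, (h_n, c_n) = lstm.forward(image)
print(len(outputs), len(outputs[-1]))  # number of time steps, hidden units per step
print(outputs[-1] == h_n)              # last-step output is the final hidden state
```

Because every row passes through the recurrent state before the last step, only that final vector needs to go to the linear classifier.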
    import torch
    from torchvision import datasets, transforms

    # Hyperparameters
    EPOCH = 1
    BATCH_SIZE = 64
    TIME_STEP = 28    # RNN time steps = image height (one row per step)
    INPUT_SIZE = 28   # RNN input size = image width (pixels per row)
    LR = 0.01
    DOWNLOAD_MNIST = True

    train_data = datasets.MNIST(root='./mnist', train=True, transform=transforms.ToTensor(), download=DOWNLOAD_MNIST)
    train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

    test_data = datasets.MNIST(root='./mnist', train=False, transform=transforms.ToTensor(), download=DOWNLOAD_MNIST)
    # Shuffling the test set is unnecessary: order does not affect accuracy
    test_loader = torch.utils.data.DataLoader(dataset=test_data, batch_size=BATCH_SIZE, shuffle=False)


    class RNN(torch.nn.Module):
        def __init__(self):
            super(RNN, self).__init__()
            self.rnn = torch.nn.LSTM(
                input_size=INPUT_SIZE,
                hidden_size=64,     # number of hidden units
                num_layers=1,
                batch_first=True,   # input shape: (batch, time_step, input_size)
            )
            self.out = torch.nn.Linear(64, 10)  # 64 hidden units -> 10 digit classes

        def forward(self, x):
            # r_out: (batch, time_step, hidden_size) = [64, 28, 64]
            # h_n, c_n: final hidden and cell states, each (num_layers, batch, hidden_size)
            r_out, (h_n, c_n) = self.rnn(x, None)  # None: zero initial states
            out = self.out(r_out[:, -1, :])       # last time step only -> [64, 10]
            return out


    # With the default batch_first=False the input would be (time_step, batch, input_size)
    rnn = RNN()
    print(rnn)

    optimizer = torch.optim.Adam(rnn.parameters(), lr=LR)
    loss_func = torch.nn.CrossEntropyLoss()

    for epoch in range(EPOCH):
        for step, (x, y) in enumerate(train_loader):
            b_x = x.view(-1, TIME_STEP, INPUT_SIZE)  # reshape (batch, 1, 28, 28) to (batch, time_step, input_size)
            output = rnn(b_x)
            loss = loss_func(output, y)              # y is already a 1-D tensor of class indices
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if step % 50 == 0:
                # Evaluate on the whole test set and report one accuracy value
                rnn.eval()
                correct, total = 0, 0
                with torch.no_grad():
                    for test_x, test_y in test_loader:
                        test_output = rnn(test_x.view(-1, TIME_STEP, INPUT_SIZE))
                        pred_y = torch.max(test_output, 1)[1]
                        correct += (pred_y == test_y).sum().item()
                        total += test_y.size(0)
                rnn.train()
                print('epoch', epoch, '| step', step, '| loss %.4f' % loss.item(), '| test accuracy %.4f' % (correct / total))
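The model's forward method returns raw scores (logits) with no softmax, because `torch.nn.CrossEntropyLoss` applies log-softmax internally and then takes the negative log-probability of the target class, averaged over the batch by default. A minimal plain-Python sketch of that computation, assuming no class weights, no ignored indices and no label smoothing (the function name `cross_entropy` is hypothetical):

```python
import math


def cross_entropy(logits_batch, targets):
    """Mean over the batch of -log(softmax(logits)[target])."""
    total = 0.0
    for logits, target in zip(logits_batch, targets):
        m = max(logits)  # subtract the max before exp() for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
        total += log_sum_exp - logits[target]  # -log softmax at the target class
    return total / len(targets)


logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]]  # batch of 2 samples, 3 classes
targets = [0, 2]
print(cross_entropy(logits, targets))
```

With all-zero logits over 10 classes the loss is log(10), the value an untrained classifier starts near; a confident correct prediction drives it toward 0.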
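MNIST's test set has 10,000 images, so with BATCH_SIZE=64 the loader yields 156 full batches plus a final batch of 16 (156 × 64 = 9,984). A single test accuracy should therefore count correct predictions across all batches rather than average per-batch accuracies, which gives the short last batch as much weight as a full one. The sketch below uses hypothetical helper names and toy labels in plain Python; in PyTorch the per-batch count is `(pred_y == test_y).sum().item()`.

```python
def overall_accuracy(batches):
    """Exact accuracy: total correct / total samples over (preds, labels) pairs."""
    correct = 0
    total = 0
    for preds, labels in batches:
        correct += sum(1 for p, y in zip(preds, labels) if p == y)
        total += len(labels)
    return correct / total


def mean_of_batch_accuracies(batches):
    """Average of per-batch accuracies; over-weights a short final batch."""
    accs = [sum(1 for p, y in zip(preds, labels) if p == y) / len(labels)
            for preds, labels in batches]
    return sum(accs) / len(accs)


# Toy predictions: a full batch of 4 and a short final batch of 1.
batches = [([1, 2, 3, 4], [1, 2, 3, 0]), ([5], [6])]
print(overall_accuracy(batches))          # 3 correct out of 5 samples
print(mean_of_batch_accuracies(batches))  # (0.75 + 0.0) / 2
```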
  • Original post: https://www.cnblogs.com/tangweijqxx/p/10601394.html