zoukankan      html  css  js  c++  java
  • PyTorch 实现 MNIST 手写数字识别

    直接上代码，并附有可视化结果。

    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torchvision import datasets,transforms
    from visdom import Visdom
    from torch.utils.data import DataLoader
    
    batch_size = 512
    learning_rate = 0.01
    epoches = 20  # number of passes over the training split

    # 5:1:1 split: 50000 train / 10000 validation (both carved out of the
    # 60000-image MNIST training set) / 10000 official test images.
    to_tensor = transforms.Compose([transforms.ToTensor()])
    # download=True fetches MNIST into ./data on first run; without it the
    # constructor fails whenever the files are not already present.
    train_db = datasets.MNIST('data', train=True, download=True, transform=to_tensor)
    validation_size = 10000
    train_db, validation_db = torch.utils.data.random_split(
        train_db, [len(train_db) - validation_size, validation_size])
    test_db = datasets.MNIST('data', train=False, download=True, transform=to_tensor)
    train_loader = DataLoader(train_db, batch_size=batch_size, shuffle=True)
    # Evaluation metrics do not depend on sample order, so only training shuffles.
    validation_loader = DataLoader(validation_db, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_db, batch_size=batch_size, shuffle=False)
    class Flatten(nn.Module):
        """Collapse every dimension after the batch dimension into one.

        Uses ``view``, so it raises if *x* is not contiguous in memory.
        """

        def __init__(self):
            super(Flatten, self).__init__()

        def forward(self, x):
            # (batch, *rest) -> (batch, prod(rest)); e.g. (B, 1, 28, 28) -> (B, 784)
            return x.view(x.size(0), -1)


    # Network structure: a 4-layer fully connected classifier.
    class Mnist(nn.Module):
        """MLP mapping inputs with 784 elements per sample to 10 raw class logits."""

        def __init__(self):
            super(Mnist, self).__init__()
            self.net = nn.Sequential(
                Flatten(),
                nn.Linear(784, 512),
                nn.ReLU(inplace=True),
                nn.Linear(512, 256),
                nn.ReLU(inplace=True),
                nn.Linear(256, 128),
                nn.ReLU(inplace=True),
                # No activation after the output layer: CrossEntropyLoss expects
                # raw, unbounded logits. A trailing ReLU would clamp every negative
                # logit to 0 and kill its gradient, stalling training.
                nn.Linear(128, 10),
            )

        def forward(self, x):
            """Return raw class logits of shape (batch, 10) for the batch *x*."""
            return self.net(x)
    
    def _evaluate(model, loader, loss_fun):
        """Evaluate *model* on every batch of *loader* without tracking gradients.

        Returns ``(acc, avg_loss, last_images, last_pred)``: accuracy in percent,
        the sample-weighted mean loss, and the final batch's inputs and argmax
        predictions (for display). Leaves *model* in eval mode.
        """
        model.eval()
        correct = 0
        total = 0
        loss_sum = 0.0
        images = pred = None
        with torch.no_grad():
            for images, labels in loader:
                logits = model(images)
                pred = logits.argmax(dim=1)
                # loss_fun (CrossEntropyLoss, mean reduction) averages per batch;
                # re-weight by batch size so a smaller last batch is not over-counted.
                loss_sum += loss_fun(logits, labels).item() * images.size(0)
                correct += torch.eq(pred, labels).sum().item()
                total += images.size(0)
        if total == 0:
            return 0.0, 0.0, images, pred
        return 100. * correct / total, loss_sum / total, images, pred

    def main():
        """Train the MLP for ``epoches`` epochs, validating after each one.

        Per step, the training loss and the running accuracy of the current epoch
        are printed and appended to Visdom plots (a Visdom server must be running).
        After each epoch, validation accuracy/loss are printed, and the last
        validation batch is shown in Visdom with its predictions. The test split
        is evaluated once after training.
        """
        mod = Mnist()
        optimizer = optim.SGD(mod.parameters(), lr=learning_rate)
        loss_fun = nn.CrossEntropyLoss()
        vis = Visdom()
        vis.line([0.], [0.], win='train_loss', opts=dict(title='train_loss'))
        vis.line([0.], [0.], win='accuracy', opts=dict(title='acc'))
        global_step = 0
        for epoch in range(epoches):
            # _evaluate switches to eval mode, so re-enable training mode each epoch.
            mod.train()
            # Reset per epoch so the accuracy reflects the current epoch rather than
            # averaging in the earliest, worst batches forever.
            correct = 0
            total_num = 0
            for x, y in train_loader:
                logits = mod(x)
                train_loss = loss_fun(logits, y)
                pred = logits.argmax(dim=1)
                correct += torch.eq(y, pred).sum().item()
                total_num += x.size(0)
                optimizer.zero_grad()
                train_loss.backward()
                optimizer.step()
                global_step += 1
                acc = 100. * correct / total_num
                vis.line([train_loss.item()], [global_step], win='train_loss', update='append')
                vis.line([acc], [global_step], win='accuracy', update='append')
                print('the loss of {:d} step is {:.3f},the accuracy is {:.3f}%'.format(
                    global_step, train_loss.item(), acc))

            val_acc, val_loss, val_images, val_pred = _evaluate(
                mod, validation_loader, loss_fun)
            if val_images is not None:
                vis.images(val_images.view(-1, 1, 28, 28), win='x')
                vis.text(str(val_pred.cpu().numpy()), win='pred', opts=dict(title='pred'))
            print('the val acc of {:d} epoch is {:.3f}%, the val loss is {:.3f}'.format(
                epoch, val_acc, val_loss))

        test_acc, test_loss, _, _ = _evaluate(mod, test_loader, loss_fun)
        print('the test acc is {:.3f}%, the test loss is {:.3f}'.format(test_acc, test_loss))
    
    
    
    if __name__ == '__main__':
        main()

  • 相关阅读:
    HDU1106 排序
    HDU2050 折线分割平面
    HDU2048 神、上帝以及老天爷
    POJ1836 Alignment
    POJ1182 食物链
    HDU2067 小兔的棋盘
    HTML中的ID不能以数字开头
    automake,autoconf使用详解
    How to install Samba server on Ubuntu 12.04
    Netbeans使用UTF-8编码
  • 原文地址:https://www.cnblogs.com/daremosiranaihana/p/12591884.html
Copyright © 2011-2022 走看看