zoukankan      html  css  js  c++  java
  • pytorch 手写数字识别项目 增量式训练

    dataset.py

    '''
    准备数据集
    '''
    import torch
    from torch.utils.data import DataLoader
    from torchvision.datasets import MNIST
    from torchvision.transforms import ToTensor,Compose,Normalize
    import torchvision
    import config
    
    def mnist_dataset(train):
        """Build the MNIST dataset object for the requested split.

        Args:
            train: Forwarded unchanged as the ``train`` flag of ``MNIST``.

        Returns:
            An ``MNIST`` dataset rooted at ``../mnist`` (``download=False``)
            whose images pass through ``ToTensor`` and then ``Normalize``
            with mean 0.1307 and std 0.3081.
        """
        # Same two-step pipeline, spelled with the names imported directly
        # at the top of the file.
        to_tensor = ToTensor()
        normalize = Normalize(mean=(0.1307,), std=(0.3081,))
        preprocessing = Compose([to_tensor, normalize])

        # Prepare the MNIST dataset.
        return MNIST(
            root="../mnist",
            train=train,
            download=False,
            transform=preprocessing,
        )
    
    def get_dataloader(train=True):
        """Return a shuffled ``DataLoader`` over the MNIST split chosen by ``train``.

        Args:
            train: Selects the split passed to ``mnist_dataset`` and which
                batch size is read from ``config``
                (``train_batch_size`` when true, ``test_batch_size`` otherwise).

        Returns:
            A ``DataLoader`` built with ``shuffle=True``.
        """
        dataset = mnist_dataset(train)
        if train:
            size = config.train_batch_size
        else:
            size = config.test_batch_size
        return DataLoader(dataset, batch_size=size, shuffle=True)
    
    if __name__ == '__main__':
        # Smoke check: fetch only the first training batch (if there is one)
        # and show the image tensor size and the labels.
        first_batch = next(iter(get_dataloader()), None)
        if first_batch is not None:
            images, labels = first_batch
            print(images.size())
            print(labels)
    

      models.py

    '''定义模型'''
    
    import torch.nn as nn
    import torch.nn.functional as F
    
    class MnistModel(nn.Module):
        """Two-layer fully connected classifier: 28*28 inputs -> 100 -> 10."""

        def __init__(self):
            super().__init__()
            # Attribute names are part of the saved state_dict keys.
            self.fc1 = nn.Linear(28 * 28, 100)
            self.fc2 = nn.Linear(100, 10)

        def forward(self, image):
            """Return log-probabilities over the 10 outputs for each input row.

            ``image`` is viewed as rows of 28*28 values, passed through
            ``fc1``, ReLU and ``fc2``, then ``log_softmax`` over the last dim.
            """
            flat = image.view(-1, 28 * 28)
            hidden = F.relu(self.fc1(flat))
            logits = self.fc2(hidden)
            return F.log_softmax(logits, dim=-1)
    

      config.py

    '''
    项目配置
    '''
    import torch
    
    # Batch sizes read by the data loader for the train and test splits.
    train_batch_size = 128
    test_batch_size = 128

    # Use CUDA when torch reports it as available, otherwise the CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    

      train.py

    '''
    进行模型的训练
    '''
    from dataset import get_dataloader
    from models import MnistModel
    from torch import optim
    import torch.nn.functional as F
    import config
    from tqdm import tqdm
    import numpy as np
    import torch
    import os
    from eval import eval
    
    # Instantiate the model and the optimizer (the loss is computed inline in train()).
    model = MnistModel().to(config.device)
    optimizer = optim.Adam(model.parameters(),lr=0.001)

    # Incremental training: resume from the checkpoint of a previous run when one
    # exists. map_location remaps the stored tensors onto the device selected in
    # config, so a checkpoint written on a GPU machine still loads on a CPU-only
    # one. The optimizer state is only restored together with the model weights,
    # and only if its own file is present.
    if os.path.exists("./model/mnist_net.pt"):
        model.load_state_dict(
            torch.load("./model/mnist_net.pt", map_location=config.device)
        )
        if os.path.exists("./model/mnist_optimizer.pt"):
            optimizer.load_state_dict(
                torch.load("./model/mnist_optimizer.pt", map_location=config.device)
            )
    
    
    #迭代训练
    
    def _save_checkpoint():
        """Write the current model and optimizer state into the ``model`` directory."""
        # torch.save does not create missing directories, so make sure it exists.
        os.makedirs("model", exist_ok=True)
        torch.save(model.state_dict(),"model/mnist_net.pt")
        torch.save(optimizer.state_dict(),"model/mnist_optimizer.pt")


    def train(epoch):
        """Run one pass over the training data and checkpoint the result.

        Args:
            epoch: Epoch number; only used in the progress-bar description.

        The module-level ``model`` and ``optimizer`` are updated in place.
        A checkpoint is written every 10 batches and once more after the last
        batch, so the files on disk always hold the final weights of the epoch.
        """
        train_dataloader = get_dataloader(train=True)
        bar = tqdm(enumerate(train_dataloader),total=len(train_dataloader))
        total_loss = []
        for idx,(images,target) in bar:
            images = images.to(config.device)
            target = target.to(config.device)
            # Reset the gradients accumulated by the previous step.
            optimizer.zero_grad()
            # Forward pass.
            output = model(images)
            # The model emits log-probabilities, hence nll_loss.
            loss = F.nll_loss(output,target)
            total_loss.append(loss.item())
            # Backward pass: compute the gradients.
            loss.backward()
            # Parameter update.
            optimizer.step()

            if idx%10 ==0:
                bar.set_description("epoch:{} idx:{},loss:{}".format(epoch,idx,np.mean(total_loss)))
                _save_checkpoint()

        # The periodic save above misses the batches after the last multiple of
        # 10; save once more so the checkpoint matches the end of the epoch.
        _save_checkpoint()
    
    if __name__ == '__main__':
        # Ten epochs, each followed by an evaluation run.
        epoch_index = 0
        while epoch_index < 10:
            train(epoch_index)
            eval()
            epoch_index += 1
    

      eval.py

    '''
    Evaluate the model on the test set
    '''
    from dataset import get_dataloader
    from models import MnistModel
    from torch import optim
    import torch.nn.functional as F
    import config
    import numpy as np
    import torch
    import os
    
    
    
    
    # Evaluation
    
    def eval():
        """Evaluate the saved checkpoint on the test split and print the result.

        A fresh ``MnistModel`` is built; when ``./model/mnist_net.pt`` exists its
        weights are loaded from that file, otherwise the freshly initialised
        model is evaluated. Prints the mean loss and accuracy over all test
        samples and returns nothing.
        """
        # Only the model is needed for evaluation; no optimizer is created.
        model = MnistModel().to(config.device)
        if os.path.exists("./model/mnist_net.pt"):
            # map_location lets a checkpoint saved on another device load here.
            model.load_state_dict(
                torch.load("./model/mnist_net.pt", map_location=config.device)
            )
        model.eval()

        test_dataloader = get_dataloader(train=False)
        loss_sum = 0.0
        correct = 0
        seen = 0
        with torch.no_grad():
            for images, target in test_dataloader:
                images = images.to(config.device)
                target = target.to(config.device)
                # Forward pass.
                output = model(images)
                # Accumulate per-sample sums rather than per-batch means, so
                # batches of different size (e.g. a shorter final batch) are
                # weighted correctly.
                loss_sum += F.nll_loss(output, target, reduction="sum").item()
                # Predicted class is the index of the largest log-probability.
                correct += output.argmax(dim=-1).eq(target).sum().item()
                seen += target.size(0)

        # An empty loader yields nan, as the mean of an empty list would.
        mean_loss = loss_sum / seen if seen else float("nan")
        mean_acc = correct / seen if seen else float("nan")
        print("test loss:{},test acc:{}".format(mean_loss, mean_acc))
    
    if __name__ == '__main__':
        # Running this file directly performs a single evaluation pass.
        eval()
    

      

    D:\anaconda\python.exe C:/Users/liuxinyu/Desktop/pytorch_test/day3/手写数字识别/train.py
    epoch:0 idx:460,loss:0.32289110562095413: 100%|██████████| 469/469 [00:24<00:00, 19.05it/s]
    test loss:0.17968503131142147,test acc:0.9453125
    epoch:1 idx:460,loss:0.15012750004513145: 100%|█████████▉| 468/469 [00:20<00:00, 22.10it/s]epoch:1 idx:460,loss:0.15012750004513145: 100%|██████████| 469/469 [00:20<00:00, 22.52it/s]
    test loss:0.12370304338916947,test acc:0.9624208860759493
    epoch:2 idx:460,loss:0.10398845713577534:  99%|█████████▉| 464/469 [00:21<00:00, 22.78it/s]epoch:2 idx:460,loss:0.10398845713577534: 100%|█████████▉| 467/469 [00:21<00:00, 22.71it/s]epoch:2 idx:460,loss:0.10398845713577534: 100%|██████████| 469/469 [00:21<00:00, 21.82it/s]
    test loss:0.10385569722592077,test acc:0.9697389240506329
    epoch:3 idx:460,loss:0.07973297938720653: 100%|█████████▉| 467/469 [00:22<00:00, 23.12it/s]epoch:3 idx:460,loss:0.07973297938720653: 100%|██████████| 469/469 [00:22<00:00, 20.84it/s]
    test loss:0.08691684670652015,test acc:0.9754746835443038
    epoch:4 idx:460,loss:0.0650228117158285: 100%|█████████▉| 468/469 [00:21<00:00, 24.06it/s]epoch:4 idx:460,loss:0.0650228117158285: 100%|██████████| 469/469 [00:21<00:00, 21.79it/s]
    test loss:0.0803159438309413,test acc:0.9760680379746836
    epoch:5 idx:460,loss:0.05270117848966101: 100%|██████████| 469/469 [00:21<00:00, 21.92it/s]
    test loss:0.08102699166423158,test acc:0.9759691455696202
    epoch:6 idx:460,loss:0.04386751471317642: 100%|██████████| 469/469 [00:19<00:00, 24.58it/s]
    test loss:0.07991968260347089,test acc:0.9769580696202531
    epoch:7 idx:460,loss:0.03656852366544161: 100%|██████████| 469/469 [00:15<00:00, 31.20it/s]
    test loss:0.07767781678917288,test acc:0.9774525316455697
    epoch:8 idx:460,loss:0.03112584312896925: 100%|██████████| 469/469 [00:14<00:00, 32.41it/s]
    test loss:0.07755146227494071,test acc:0.9773536392405063
    epoch:9 idx:460,loss:0.025217091969725495: 100%|██████████| 469/469 [00:14<00:00, 31.53it/s]
    test loss:0.07112929566845863,test acc:0.9802215189873418
    

      接口interface.py

    '''
    Prediction interface: recognise the digit in an image file
    '''
    from models import MnistModel
    from torch import optim
    import config
    import torch
    import os
    import cv2
    import torchvision.transforms as transforms
    
    # Preprocessing applied to every input picture: tensor conversion followed
    # by normalisation with mean 0.1307 / std 0.3081. The misspelled name is
    # kept because interface() refers to it.
    _normalize = transforms.Normalize(mean=(0.1307,), std=(0.3081,))
    tranform = transforms.Compose([transforms.ToTensor(), _normalize])

    # Instantiate the model and the optimizer.
    model = MnistModel()
    optimizer = optim.Adam(model.parameters(), lr=0.01)

    # Restore both saved states onto the CPU when a model checkpoint exists.
    _cpu = torch.device("cpu")
    if os.path.exists("./model/mnist_net.pt"):
        _model_state = torch.load("./model/mnist_net.pt", map_location=_cpu)
        model.load_state_dict(_model_state)
        _optimizer_state = torch.load("model/mnist_optimizer.pt", map_location=_cpu)
        optimizer.load_state_dict(_optimizer_state)
    
    #预测接口
    def interface(pic_path):
        """Recognise the digit in the picture at ``pic_path``.

        Args:
            pic_path: Path of the image file to read with OpenCV.

        Returns:
            The predicted class index as an ``int``; it is also printed.

        Raises:
            FileNotFoundError: If OpenCV cannot read the file.
        """
        img = cv2.imread(pic_path)
        # cv2.imread signals a missing/unreadable file by returning None.
        if img is None:
            raise FileNotFoundError("cannot read image: {}".format(pic_path))
        # imread delivers channels in BGR order, so convert with BGR2GRAY.
        img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        # The model views its input as rows of 28*28 values; bring any other
        # picture size to 28x28 first.
        if img_gray.shape != (28, 28):
            img_gray = cv2.resize(img_gray, (28, 28), interpolation=cv2.INTER_AREA)
        # Add the batch dimension expected by the model.
        batch = tranform(img_gray).unsqueeze(0)
        with torch.no_grad():
            # Forward pass.
            output = model(batch)
        digit = int(output.argmax(dim=-1)[0])
        print("识别结果为:", digit)
        return digit
    
    
    if __name__ == '__main__':
        # Keep asking for picture names; each one is looked up as
        # ./pic_test/<name>.png and passed to interface().
        while True:
            name = input("请输入图片地址:")
            path = "./pic_test/{}.png".format(name)
            print(path)
            interface(path)
    

      

    多思考也是一种努力,做出正确的分析和选择,因为我们的时间和精力都有限,所以把时间花在更有价值的地方。
  • 相关阅读:
    Qt:移动无边框窗体(使用Windows的SendMessage)
    github atom 试用
    ENode框架Conference案例转载
    技术
    NET 领域驱动设计实战系列总结
    mac 配置Python集成开发环境
    User、Role、Permission数据库设计ABP
    Oracle 树操作
    Oracle 用户权限管理方法
    Web Api 2, Oracle and Entity Framework
  • 原文地址:https://www.cnblogs.com/LiuXinyu12378/p/12314982.html
Copyright © 2011-2022 走看看