zoukankan      html  css  js  c++  java
  • PyTorch保存、加载模型,PyTorch中已封装的网络模型

    state_dict()函数可以返回所有的状态数据。load_state_dict()函数可以加载这些状态数据。

    推荐使用:

    #保存
    t.save(net.state_dict(),"net.pth")
    #加载
    net2=Net()
    net2.load_state_dict(t.load("net.pth"))

    不推荐直接save与load,因为这种方式严重依赖模型定义方法以及文件路径结构等,容易出问题。

    t.save(net,"net.pth")
    net2=t.load("net.pth")

     【PyTorch中已封装的网络模型】https://pytorch.org/docs/stable/torchvision/index.html

     从上图看出,有针对分类问题、语义分割、目标识别、视频分类的模型。

    以分类模型为例,PyTorch中已封装的模型如下:

     使用方式,参考标黄部分

    ######################################## 1、使用torchvision加载并预处理数据集
    
    #### datasets的ImageFolder读图
    from torchvision.datasets import ImageFolder
    dataset=ImageFolder("E:/data/dogcat_2/train/") #获取路径,返回的是所有图的data、label
    from torchvision import transforms as T #设置格式化条件
    transform=T.Compose([T.Resize((64,64)), 
                         T.ToTensor(), #PIL Image转Tensor,[0,255]自动归一化为[0,1]
                         T.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5]) #标准化,减均值除标准差
                        ])
    dataset=ImageFolder("E:/data/dogcat_2/train/",transform=transform)
    testset=ImageFolder("E:/data/dogcat_2/test/",transform=transform)
    
    #### DataLoader
    from torch.utils.data import DataLoader
    dataloader=DataLoader( dataset,batch_size=4,shuffle=True,num_workers=2 )
    testloader=DataLoader(testset,batch_size=4,shuffle=True,num_workers=2)
    
    #### 显示第1个batch的4幅图(随机)
    from torchvision.transforms import ToPILImage
    from torchvision.utils import make_grid
    dataiter = iter(dataloader)
    (images, labels) = dataiter.next()
    print(labels) #打印标签
    show=ToPILImage() 
    show(make_grid(images*0.5+0.5)).resize((4*64,64)) 
    
    ######################################## 2、定义网络
    from torchvision import models
    net=models.alexnet()
    
    ######################################## 3、定义损失函数和优化器
    import torch.nn as nn
    from torch import optim
    criterion=nn.CrossEntropyLoss() #交叉熵损失函数
    optimizer=optim.SGD(net.parameters(),lr=0.001,momentum=0.9) #随机梯度下降法,指定要调整的参数和学习率,动量算法加速更新权重
    
    ######################################## 4、训练网络并更新网络参数
    for epoch in range(2):  # 在整个数据集上轮番训练多次,轮训一次叫一个回合(epoch)
    
        running_loss = 0.0
        for i, data in enumerate(dataloader, 0):
            
            # 输入数据
            inputs, labels = data
            
            # 梯度清零
            optimizer.zero_grad()
    
            # forward + backward
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            
            #更新参数
            optimizer.step()
    
            # 打印一些关于训练的统计信息
            running_loss += loss.item()
            if i % 200 == 199:    # 每 200 个batch打印一次
                print('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss / 200))
                running_loss = 0.0
    
    print('Finished Training')
    
    ######################################## 5、测试网络
    import torchvision as tv
    import torch as t
    #datasets测试集中前4幅图,并输出标签
    dataiter = iter(testloader)
    (images, labels) = dataiter.next() #返回1个batch(4张图)
    
    # 输出图像和正确的类标签
    #print('实际的label:', ' '.join('%5s' % classes[labels[j]] for j in range(4)))
    show(tv.utils.make_grid((images+1)/2)).resize((400,100))
    
    #测试
    outputs = net(images) #预测上边得到的batch(4张图),返回得分(每一类都打分)
    _, predicted = t.max(outputs, 1) #每1张图得分最高的那个类的下标
    
    print(outputs)
    print(predicted)
    #print('预测结果:', ' '.join('%5s' % classes[predicted[j]] for j in range(4)))
    show(tv.utils.make_grid((images+1)/2)).resize((400,100))
    
    #测试整个测试集
    correct = 0 #预测正确的图片数
    total = 0 #总共的图片数
    with t.no_grad():
        for data in testloader:
            images, labels = data
            outputs = net(images)
            _, predicted = t.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    print('10000张测试集中的准确率: %d %%' % (100 * correct / total))
  • 相关阅读:
    Hive—数据库级增、删、改、查
    Kafka—Bootstrap broker hadoop102:2181 (id: -1 rack: null) disconnected
    Kafka—命令行操作
    Kafka—Kafka安装部署
    Kafka—Java HotSpot(TM) 64-Bit Server VM warning: INFO: os::commit_memory(0x00000000c0000000, 1073741824, 0) failed; error='Cannot allocate memory' (errno=12)
    mysql—Job for mysqld.service failed because the control process exited with error code. See "systemctl status mysqld.service" and "journalctl -xe" for details.
    Linux—安装MySQL数据库
    Linux—Yum源配置
    Hive-FAILED: SemanticException org.apache.hadoop.hive.ql.metadata.HiveException: java.lang.RuntimeException: Unable to instantiate org.apache.hadoop.hive.ql.metadata.SessionHiveMetaStoreClient
    Linux—Crontab定时任务
  • 原文地址:https://www.cnblogs.com/xixixing/p/12786323.html
Copyright © 2011-2022 走看看