zoukankan      html  css  js  c++  java
  • pytorch(二十七):模型的保存与加载

    pytorch保存模型非常简单,主要有两种方法:

    1. 只保存参数;(官方推荐)
    2. 保存整个模型 (结构+参数)。
      由于保存整个模型将耗费大量的存储,故官方推荐只保存参数,然后在建好模型的基础上加载。本文介绍两种方法,但只就第一种方法进行举例详解。

    一、只保存参数

    1.保存

    一般地,采用一条语句即可保存参数:

    torch.save(model.state_dict(), path)
    

    其中model指定义的模型实例变量,如 model=vgg16( ), path是保存参数的路径,如 path='./model.pth' , path='./model.tar', path='./model.pkl', 保存参数的文件一定要有后缀扩展名。

    特别地,如果还想保存某一次训练采用的优化器、epochs等信息,可将这些信息组合起来构成一个字典,然后将字典保存起来:

    state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}
    torch.save(state, path)
    
    2.加载

    针对上述第一种情况,也只需要一句即可加载模型:

    model.load_state_dict(torch.load(path))
    

    针对上述第二种以字典形式保存的方法,加载方式如下:

    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    epoch = checkpoint(['epoch'])
    

    需要注意的是,只保存参数的方法在加载的时候要事先定义好跟原模型一致的模型,并在该模型的实例对象(假设名为model)上进行加载,即在使用上述加载语句前已经有定义了一个和原模型一样的Net, 并且进行了实例化 model=Net( ) 。

    另外,如果每一个epoch或每n个epoch都要保存一次参数,可设置不同的path,如 path='./model' + str(epoch) +'.pth',这样,不同epoch的参数就能保存在不同的文件中,选择保存识别率最大的模型参数也一样,只需在保存模型语句前加个if判断语句即可。

    下面给出一个具体的例子程序,该程序只保存最新的参数:

    #-*- coding:utf-8 -*-
    
    '''本文件用于举例说明pytorch保存和加载文件的方法'''
    
    __author__ = 'puxitong from UESTC'
    
    
    import torch as torch
    import torchvision as tv
    import torch.nn as nn
    import torch.optim as optim
    import torch.nn.functional as F
    import torchvision.transforms as transforms
    from torchvision.transforms import ToPILImage
    import torch.backends.cudnn as cudnn
    import datetime
    import argparse
    
    # 参数声明
    batch_size = 32
    epochs = 10
    WORKERS = 0   # dataloder线程数
    test_flag = True  #测试标志,True时加载保存好的模型进行测试 
    ROOT = '/home/pxt/pytorch/cifar'  # MNIST数据集保存路径
    log_dir = '/home/pxt/pytorch/logs/cifar_model.pth'  # 模型保存路径
    
    # 加载MNIST数据集
    transform = tv.transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
    
    train_data = tv.datasets.CIFAR10(root=ROOT, train=True, download=True, transform=transform)
    test_data = tv.datasets.CIFAR10(root=ROOT, train=False, download=False, transform=transform)
    
    train_load = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=WORKERS)
    test_load = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=WORKERS)
    
    
    # 构造模型
    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
            self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
            self.conv3 = nn.Conv2d(128, 256, 3, padding=1)
            self.conv4 = nn.Conv2d(256, 256, 3, padding=1)
            self.pool = nn.MaxPool2d(2, 2)
            self.fc1 = nn.Linear(256 * 8 * 8, 1024)
            self.fc2 = nn.Linear(1024, 256)
            self.fc3 = nn.Linear(256, 10)
        
        
        def forward(self, x):
            x = F.relu(self.conv1(x))
            x = self.pool(F.relu(self.conv2(x)))
            x = F.relu(self.conv3(x))
            x = self.pool(F.relu(self.conv4(x)))
            x = x.view(-1, x.size()[1] * x.size()[2] * x.size()[3])
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = self.fc3(x)
            return x
    
    
    model = Net().cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    
    
    # 模型训练
    def train(model, train_loader, epoch):
        model.train()
        train_loss = 0
        for i, data in enumerate(train_loader, 0):
            x, y = data
            x = x.cuda()
            y = y.cuda()
            optimizer.zero_grad()
            y_hat = model(x)
            loss = criterion(y_hat, y)
            loss.backward()
            optimizer.step()
            train_loss += loss
        loss_mean = train_loss / (i+1)
        print('Train Epoch: {}	 Loss: {:.6f}'.format(epoch, loss_mean.item()))
    
    # 模型测试
    def test(model, test_loader):
        model.eval()
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for i, data in enumerate(test_loader, 0):
                x, y = data
                x = x.cuda()
                y = y.cuda()
                optimizer.zero_grad()
                y_hat = model(x)
                test_loss += criterion(y_hat, y).item()
                pred = y_hat.max(1, keepdim=True)[1]
                correct += pred.eq(y.view_as(pred)).sum().item()
            test_loss /= (i+1)
            print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)
    '.format(
                test_loss, correct, len(test_data), 100. * correct / len(test_data)))
    
    
    def main():
    
        # 如果test_flag=True,则加载已保存的模型
        if test_flag:
            # 加载保存的模型直接进行测试机验证,不进行此模块以后的步骤
            checkpoint = torch.load(log_dir)
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            epochs = checkpoint['epoch']
            test(model, test_load)
            return
    
        for epoch in range(0, epochs):
            train(model, train_load, epoch)
            test(model, test_load)
            # 保存模型
            state = {'model':model.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch':epoch}
            torch.save(state, log_dir)
    
    if __name__ == '__main__':
        main()
    
    3.在加载的模型基础上继续训练

    在训练模型的时候可能会因为一些问题导致程序中断,或者常常需要观察训练情况的变化来更改学习率等参数,这时候就需要加载中断前保存的模型,并在此基础上继续训练,这时候只需要对上例中的 main() 函数做相应的修改即可,修改后的 main() 函数如下:

    def main():
    
        # 如果test_flag=True,则加载已保存的模型
        if test_flag:
            # 加载保存的模型直接进行测试机验证,不进行此模块以后的步骤
            checkpoint = torch.load(log_dir)
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            start_epoch = checkpoint['epoch']
            test(model, test_load)
            return
    
        # 如果有保存的模型,则加载模型,并在其基础上继续训练
        if os.path.exists(log_dir):
            checkpoint = torch.load(log_dir)
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            start_epoch = checkpoint['epoch']
            print('加载 epoch {} 成功!'.format(start_epoch))
        else:
            start_epoch = 0
            print('无保存模型,将从头开始训练!')
    
        for epoch in range(start_epoch+1, epochs):
            train(model, train_load, epoch)
            test(model, test_load)
            # 保存模型
            state = {'model':model.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch':epoch}
            torch.save(state, log_dir)
    

    以上方法,如果想在命令行进行操作执行,都只需加入argpase模块参数即可,相关方法可参考我的博客

    二、保存整个模型

    1.保存
    torch.save(model, path)
    
    2.加载
    model = torch.load(path)
    

    用法可参照上例。



    作者:西北小生_
    链接:https://www.jianshu.com/p/1cd6333128a1
    来源:简书
    著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。
  • 相关阅读:
    MATLAB 高斯金字塔
    MATLAB 灰度图直方图均衡化
    MATLAB 生成高斯图像
    MATLAB 灰度、二值图像腐蚀膨胀
    MATLAB 中值滤波
    MATLAB 最大中值滤波
    MATLAB 最大均值滤波
    MATLAB 图像加噪,各种滤波
    MATLAB 图像傅里叶变换,幅度谱,相位谱
    十款最佳人工智能软件
  • 原文地址:https://www.cnblogs.com/zhangxianrong/p/15151089.html
Copyright © 2011-2022 走看看