zoukankan      html  css  js  c++  java
  • pytorch-模型保存与加载自己训练的模型详解 

    filename = 'cvae_' + str(epoch+1) + '.pkl'
    save_path = save_dir / Path(filename)
    states = {}
    states['model'] = cvae.state_dict() # 模型参数
    states['z_dim'] = args.z_dim
    states['x_dim'] = args.x_dim
    states['s_dim'] = args.s_dim
    states['optim'] = cvae.state_dict()
    torch.save(states, str(save_path)) #检查点:将states字典存放在save_path文件下 
    
    保存和加载模型的时候,配对的函数:
    对于仅保存state_dict()的方式,那保存和加载模型的方式为:
    保存:torch.save(model.state_dict(), PATH)
    加载:model.laod_state_dict(torch.load(PATH))
    一般加载模型是在训练完成后用模型做测试,这时候加载模型记得要加上model.eval(),把模型切换到evaluation模式,这时候会调整dropout和bactch的模式。

    对于保存和加载整个模型的情况:
    torch.save(model, PATH)
    model = torch.load(PATH)
    可以看到,前面的model.load_state_dict()和这里的不同,前面的情况需要你先定义一个模型,然后再load_state_dict()
    但是这里load整个模型,会把模型的定义一起load进来。完成了模型的定义和加载参数的两个过程。
    详细代码
     1     def save(self):
     2         save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)
     3 
     4         if not os.path.exists(save_dir):
     5             os.makedirs(save_dir)
     6 
     7         torch.save(self.G.state_dict(), os.path.join(save_dir, self.model_name + '_G.pkl'))
     8         torch.save(self.D.state_dict(), os.path.join(save_dir, self.model_name + '_D.pkl'))
     9 
    10         with open(os.path.join(save_dir, self.model_name + '_history.pkl'), 'wb') as f:
    11             pickle.dump(self.train_hist, f)
    12 # 使用方法:对模型初始化以后,使用以下方法,加载模型的参数,以至于不用再对数据集进行训练
    13     def load(self):
    14         save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)
    15 
    16         self.G.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_G.pkl')))
    17         self.D.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_D.pkl'))

    note:

    pickle.dump(obj, file, [,protocol]) 序列化对象,将对象obj保存到文件file中去。self.train_hist用于存放模型文件

    pickle.load(file) 反序列化对象,将文件中的数据解析为一个python对象。file中有read()接口和readline()接口



  • 相关阅读:
    day14_oracle数据库备份
    day13_存储过程小记
    day13_先沃联盟定时任务
    day13_自动抽取数据——监控存储过程
    [笔记]《HTTP权威指南》- 实体和编码
    [笔记]《白帽子讲Web安全》- Web框架安全
    [笔记]《Vue移动开发实战技巧》- Vue-router使用
    WPF与Win32互操作
    [翻译]HTML5
    学习资料收藏
  • 原文地址:https://www.cnblogs.com/shuangcao/p/12492706.html
Copyright © 2011-2022 走看看