  • PyTorch model saving and loading

    In deep learning, it is very useful to save a model and its parameters during or after training, so that they can be reloaded later (for example, to resume training or to run inference).

    PyTorch provides two ways to do this.

    1. Save and restore only the parameters (recommended)

    import torch

    torch.save(the_model.state_dict(), PATH)    # save only the parameters (the state_dict) to PATH

    the_model = TheModelClass(*args, **kwargs)    # create the_model as an instance of TheModelClass
    the_model.load_state_dict(torch.load(PATH))    # load the parameters from PATH into the_model
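    Here is a minimal, self-contained sketch of approach 1. It assumes a hypothetical `TinyNet` class standing in for `TheModelClass`, and it writes the file to a temporary directory instead of a real `PATH`. It saves the state_dict, rebuilds the model from its class, loads the parameters back, and checks that the two models give the same output.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical small model used only for this example."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 4, kernel_size=3)
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        x = torch.relu(self.conv1(x))  # (N, 4, H-2, W-2)
        x = x.mean(dim=(2, 3))         # global average pooling -> (N, 4)
        return self.fc(x)


# Hypothetical file location in a fresh temporary directory.
path = os.path.join(tempfile.mkdtemp(), "tinynet_params.pt")

# Save only the parameters (and buffers) of the "trained" model.
trained = TinyNet()
torch.save(trained.state_dict(), path)

# Later: rebuild the architecture in code, then load the saved parameters into it.
restored = TinyNet()
restored.load_state_dict(torch.load(path))
restored.eval()  # inference mode; matters for layers such as dropout or batch norm

# With identical parameters, both models produce the same output.
x = torch.randn(1, 1, 8, 8)
with torch.no_grad():
    same = torch.allclose(trained(x), restored(x))
print(same)
```

    Only tensors are stored in the file, so the class definition stays in your code. This is why the file keeps working after you refactor the code, as long as the parameter names still match.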
    

    2. Save the entire model (structure and parameters)

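    torch.save(the_model, PATH)    # pickle the whole model object to PATH

    the_model = torch.load(PATH, weights_only=False)    # unpickle it; TheModelClass must be importable here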
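    The following is a small, runnable sketch of approach 2 under the same assumptions as above: `TinyNet` is a hypothetical model and the file goes to a temporary directory. It passes `weights_only=False` to `torch.load`, because loading a whole pickled module is refused when only weights are allowed (the default in recent PyTorch releases). Use this only for files you trust.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical model; its class must be importable when the file is loaded."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3, 2)

    def forward(self, x):
        return self.fc(x)


path = os.path.join(tempfile.mkdtemp(), "tinynet_full.pt")

model = TinyNet()
torch.save(model, path)  # pickles the module object together with its parameters

# Full unpickling is required to rebuild the module object; only for trusted files.
loaded = torch.load(path, weights_only=False)

print(type(loaded).__name__)
x = torch.randn(4, 3)
with torch.no_grad():
    same = torch.equal(model(x), loaded(x))  # same parameters -> identical outputs
```

    The pickle stores a reference to the class by module and name, not the class's code. Renaming or moving `TinyNet` can therefore break loading later, which is one reason approach 1 is usually preferred.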
    

    3. Get the parameters of a specific layer

    params = the_model.state_dict()    # ordered mapping: parameter name -> tensor
    for k, v in params.items():
        print(k)    # print the parameter names in the network
    print(params['conv1.weight'])    # print conv1's weight (assumes a layer named conv1)
    print(params['conv1.bias'])    # print conv1's bias
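    Here is a self-contained sketch of approach 3, assuming a hypothetical model whose first layer is named `conv1`. With this naming, its state_dict keys include `conv1.weight` and `conv1.bias`, as in the snippet above. State_dict keys follow the pattern `<submodule name>.<parameter name>`.

```python
from collections import OrderedDict

import torch.nn as nn

# Hypothetical model; naming the first layer "conv1" yields the keys
# "conv1.weight" and "conv1.bias" in its state_dict.
model = nn.Sequential(OrderedDict([
    ("conv1", nn.Conv2d(3, 8, kernel_size=3)),
    ("relu", nn.ReLU()),
    ("pool", nn.AdaptiveAvgPool2d(1)),
    ("flatten", nn.Flatten()),
    ("fc", nn.Linear(8, 10)),
]))

params = model.state_dict()
for name, tensor in params.items():
    print(name, tuple(tensor.shape))  # parameter name and its shape

# Look up the tensors of one layer by key.
conv1_w = params["conv1.weight"]
conv1_b = params["conv1.bias"]
print(conv1_w.shape, conv1_b.shape)
```

    Layers without parameters (`relu`, `pool`, `flatten`) do not appear in the state_dict. By default `state_dict()` returns detached tensors and also includes registered buffers. `named_parameters()` instead yields the live `nn.Parameter` objects, which is what you would pass to an optimizer.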
    

      

    Reference: http://www.pytorchtutorial.com/pytorch-note5-save-and-restore-models/

      

  • Original post: https://www.cnblogs.com/beatets/p/8456252.html