zoukankan      html  css  js  c++  java
  • Pytorch自定义加载预训练权重

    pytorch保存模型权重非常方便

    保存模型可以分为两种

    一种是保存整个网络(网络结构+权重参数)

    torch.save(model, 'net.pth.tar')

    一种是只保存模型的权重参数(速度快,内存占用少)

    torch.save(model.state_dict(), 'net.pth.tar')

    标准的加载还可以做一些设置

    如果加载模型只是为了进行推理测试,则将每一层的 requires_grad 置为 False,即固定这些权重参数;还需要调用 model.eval() 将模型置为测试模式,主要是将 dropout 和 batch normalization 层进行固定,否则模型的预测结果每次都会不同。

    如果希望继续训练,则调用 model.train(),以确保网络模型处于训练模式。

    然后,想使用预训练权重有非常严格的要求,要求每一层一模一样,命名都要一样(不然dict的key就不一样了

    如果两个模型实际是一样,既然是字典,可不可以手动赋值呢?

    比如,我把cnn命名成conv了,导致load失败,因此我们来手动赋值

    虽然load成功了,但是实际效果有点问题,直接用于生成,得到的图像颜色不对。

    按道理两种方法的权重应该是一模一样啊,

    前两个是手动赋值,颜色都有偏差,第三个直接load的正常。。。

    难道model.state_dict() 没有包含所有的权重信息?

    其实还有一个简便方法:使用strict=False 参数

    model.load_state_dict(checkpoint["state_dict"], strict=False)
    optimizer.load_state_dict(checkpoint["optimizer"])

    奇怪的是,model的load_state_dict有strict参数,optimizer没有

    参考链接:https://zhuanlan.zhihu.com/p/73893187

    个性签名:时间会解决一切
  • 相关阅读:
    分布式事务的解决方案
    普通平衡树(bzoj 3224)
    [学习笔记] 树链剖分
    矩阵树定理——矩阵树不是树
    哈夫曼树
    SDOI2018一轮NOI培训 题目整理
    Luogu P1119 灾后重建
    轻量级ORM框架——第二篇:Dapper中的一些复杂操作和inner join应该注意的坑(转)
    单点登录的设计与实现
    PHP如何进阶,提升自己
  • 原文地址:https://www.cnblogs.com/lfri/p/14866849.html
Copyright © 2011-2022 走看看