zoukankan      html  css  js  c++  java
  • ResNet-18 下载与保存，转换为 ONNX 模型，导出 .wts 格式的权重文件

    1. Download the pretrained model and save it to 'resnet18.pth':

    import torch
    from torch import nn
    from torch.nn import functional as F
    import torchvision
    
    def main():
        """Download pretrained ResNet-18, sanity-check it, and save it to 'resnet18.pth'.

        The whole module object (architecture + weights) is pickled with
        ``torch.save``, so loading it later needs torchvision's model classes
        to be importable.
        """
        print('cuda device count: ', torch.cuda.device_count())
        # Fall back to CPU so the script still runs on machines without a GPU.
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        # NOTE: torchvision >= 0.13 deprecates `pretrained=` in favour of `weights=`.
        net = torchvision.models.resnet18(pretrained=True)
        # net.fc = nn.Linear(512, 2)  # optional: swap in a 2-class classifier head
        net = net.to(device)
        net.eval()
        print(net)
        # Dummy batch of two 3x224x224 inputs; no gradients are needed for a shape check.
        tmp = torch.ones(2, 3, 224, 224, device=device)
        with torch.no_grad():
            out = net(tmp)
        print('resnet18 out:', out.shape)
        # The module is saved on `device`; loaders on another device need map_location.
        torch.save(net, "resnet18.pth")

    if __name__ == '__main__':
        main()

    This 'resnet18.pth' file contains both the model structure and its weights (the whole module is pickled).

    2. Load the .pth file and convert it to ONNX format:

    import torch
    
    def main():
        """Load the pickled model from 'resnet18.pth' and export it to 'resnet18_trtpose.onnx'.

        The graph is traced with one random 1x3x224x224 input placed on the same
        device as the model.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # The checkpoint may have been saved on a GPU; map it onto whichever
        # device is available so the model and the dummy input always agree.
        try:
            # torch >= 2.6 defaults to weights_only=True, which refuses to
            # unpickle a full nn.Module. This file is our own trusted output.
            model = torch.load('resnet18.pth', map_location=device, weights_only=False)
        except TypeError:
            # torch < 1.13 has no `weights_only` argument.
            model = torch.load('resnet18.pth', map_location=device)
        model.eval()
        inputs = torch.randn(1, 3, 224, 224, device=device)
        # The original passed `training=2`, which is not a TrainingMode member:
        # older torch silently exported in eval mode, newer torch raises TypeError.
        # EVAL makes that explicit (BatchNorm uses running statistics).
        torch.onnx.export(model, inputs, 'resnet18_trtpose.onnx',
                          training=torch.onnx.TrainingMode.EVAL)

    if __name__ == '__main__':
        main()

    3. Load the .pth file and export the model weights to a .wts file:

    import torch
    from torch import nn
    import torchvision
    import os
    import struct
    from torchsummary import summary
    
    def _write_wts(state_dict, path):
        """Write *state_dict* to *path* in the text .wts format.

        The first line holds the number of entries. Each following line is
        ``<name> <count>`` followed by ``count`` values, each written as a
        space-separated big-endian float32 in 8 hex digits. Every value passes
        through ``float()``, so non-float buffers are also written as float32.
        """
        pack = struct.Struct(">f").pack  # hoisted: one packer for every value
        with open(path, 'w') as f:
            f.write("{}\n".format(len(state_dict.keys())))
            for k, v in state_dict.items():
                print('key: ', k)
                print('value: ', v.shape)
                vr = v.reshape(-1).cpu().numpy()
                # Build the whole line at once instead of two tiny writes per value.
                values = "".join(" " + pack(float(vv)).hex() for vv in vr)
                f.write("{} {}{}\n".format(k, len(vr), values))

    def main():
        """Load 'resnet18.pth', run a sanity forward pass, and export weights to 'resnet18.wts'."""
        print('cuda device count: ', torch.cuda.device_count())
        # Fall back to CPU so the script also runs on machines without a GPU.
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        try:
            # torch >= 2.6 defaults to weights_only=True, which cannot load a
            # full pickled nn.Module. This file is our own trusted output.
            net = torch.load('resnet18.pth', map_location=device, weights_only=False)
        except TypeError:
            # torch < 1.13 has no `weights_only` argument.
            net = torch.load('resnet18.pth', map_location=device)
        net = net.to(device)
        net.eval()
        print('model: ', net)
        # print('state dict: ', net.state_dict().keys())
        tmp = torch.ones(1, 3, 224, 224, device=device)
        print('input: ', tmp)
        with torch.no_grad():
            out = net(tmp)
        print('output:', out)

        # torchsummary defaults to device="cuda"; pass the real device type.
        summary(net, (3, 224, 224), device=device.type)
        _write_wts(net.state_dict(), "resnet18.wts")

    if __name__ == '__main__':
        main()
  • 相关阅读:
    jQuery获取Select选择的Text和 Value(转)
    android学习---EditText
    android学习---Activity
    android学习---LinearLayout
    android学习---布局Layout
    android颜色码制表
    java面试题二
    java面试题一
    基本排序算法java实现
    Integer与int的区别
  • 原文地址:https://www.cnblogs.com/mrlonely2018/p/15078499.html
Copyright © 2011-2022 走看看