zoukankan      html  css  js  c++  java
  • opencv 调用 pytorch训练的resnet模型

    使用OpenCV的DNN模块调用pytorch训练的分类模型,这里记录一下中间的流程,主要分为模型训练,模型转换和OpenCV调用三步。

    一、训练二分类模型

    准备二分类数据,直接使用torchvision.models中的resnet18网络,主要编写的地方是自定义数据类中的__getitem__,和网络最后一层。

    • __getitem__
      将同类数据放在不同文件夹下,编写Mydataset类,在__getitem__函数中增加数据增强。
    class Mydataset(Dataset):
        ......
        def __getitem__(self, idx):
            # idx-[0->len(images)]
            img, label = self.images[idx], self.labels[idx]
            tf = transforms.Compose([
                lambda x: Image.open(x).convert('RGB'),
                transforms.Resize((int(self.resize), int(self.resize))),
                # transforms.Resize((int(self.resize * 1.25), int(self.resize * 1.25))),
                # transforms.RandomRotation(15),
                # transforms.CenterCrop(self.resize),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
            ])
    
            img = tf(img)
            label = torch.tensor(label)
            return img, label
        ......
    
    • 修改网络最后一层
      依据类别,修改最后一层的输出,主要代码如下:
    model = resnet18(pretrained=True)  # 比较好的 model
    model = nn.Sequential(*list(model.children())[:-1],  # [b, 512, 1, 1] -> 接全连接层
                          # [b, 512, 1, 1] -> [b, 512]
                          torch.nn.Flatten(),
                          nn.Linear(512, 2)).to(device)  # 添加全连接层
    
    # x = torch.randn(2, 3, 224, 224)
    # print(model(x).shape)
    # 定义损失函数
    criterion = nn.CrossEntropyLoss()
    # 定义迭代参数的算法
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    

    二、Pytorch模型转为ONNX模型

    直接调用torch.onnx接口可将模型导出为ONNX格式,这里主要介绍验证导出模型是否正确

    import torch
    from torchvision import transforms
    from PIL import Image
    from torchvision.models import resnet18
    import torch.nn as nn
    import torch.onnx
    import onnx
    import onnxruntime
    import numpy as np
    
    torch_model = "./resnet18-2Class.pkl"
    onnx_save_path = "./resnet18-2Class.onnx"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    data = torch.randn(1, 3, 224, 224, dtype=torch.float, device=device)
    model = resnet18(pretrained=True)
    model = nn.Sequential(*list(model.children())[:-1],  # [b, 512, 1, 1] -> 接全连接层
                          nn.Flatten(),  # [b, 512, 1, 1] -> [b, 512]
                          nn.Linear(512, 2)).to(device)
    model.load_state_dict(torch.load(torch_model))
    model.eval()
    
    print("Start convert model to onnx...")
    torch.onnx.export(model,
                      data,
                      onnx_save_path,
                      opset_version=10,
                      do_constant_folding=True,  # 是否执行常量折叠优化
                      input_names=["input"],  # 输入名
                      output_names=["output"],  # 输出名
                      dynamic_axes={"input": {0: "batch_size"},  # 批处理变量
                                    "output": {0: "batch_size"}}
    )
    
    print("convert onnx is Done!")
    
    
    def to_numpy(tensor):
        return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
    
    
    def get_test_transform():
        tf = transforms.Compose([
            lambda x: Image.open(x).convert('RGB'),
            transforms.Resize((224, 224)),
            # transforms.CenterCrop(self.resize),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
    
        return tf
    
    
    img_path = "./1.png"
    img = get_test_transform()(img_path)
    img = img.unsqueeze(0)  # --> NCHW
    print("input img mean {} and std {}".format(img.mean(), img.std()))
    
    torch_out = model(img.to(device))
    print("torch predict: ", torch_out)
    
    # onnx
    resnet_session = onnxruntime.InferenceSession(onnx_save_path)
    inputs = {resnet_session.get_inputs()[0].name: to_numpy(img)}
    onnx_out = resnet_session.run(None, inputs)[0]
    print("onnx predict: ", onnx_out)
    

    三、OpenCV调用ONNX模型进行分类

    这里主要工作是对数据进行预处理,在第一部分中的__getitem__函数的增强部分,转为openCV图像处理如下,其他直接调用dnn模块下的readNetFromONNX(modelPath)即可。

    cv::Mat img = cv::imread(imgPath);
    img.convertTo(img, CV_32FC3);
    cv::cvtColor(img, img, cv::COLOR_BGR2RGB);
    cv::resize(img, img, cv::Size(224, 224));
    img = img / 255.0;
    std::vector<float> mean_value{ 0.485, 0.456, 0.406 };
    std::vector<float> std_value{ 0.229, 0.224, 0.225 };
    cv::Mat dst;
    std::vector<cv::Mat> rgbChannels(3);
    cv::split(img, rgbChannels);
    for (auto i = 0; i < rgbChannels.size(); i++)
    {
        rgbChannels[i] = (rgbChannels[i] - mean_value[i]) / std_value[i];
    }
    cv::merge(rgbChannels, dst);
    

    其中有一个注意点,就是同一张图片用torchvision.transforms中的Resize()和OpenCV的resize()函数处理的结果会有一点差别,这是因为transforms中默认使用的PIL的resize进行处理,除了默认的双线性插值,还会进行antialiasing,不过这个对于分类任务影响并不太大。

    参考链接

    OpenCV调用Caffe GoogLeNet
    OpenCV自定义算子
    多标签分类

  • 相关阅读:
    学习笔记—查找
    水晶报表图表制作问题
    Chrome对最小字体的限制
    Devexpress的ASPxDateEdit控件设置其‘today’ 为客户端当前日期
    水晶报表多表数据源
    System.Web.HttpValueCollection.ThrowIfMaxHttpCollectionKeysExceeded
    利用水晶报表制作甘特图
    水晶报表打印时最后多打印一空白页
    day3学习
    Python高级自动化培训day1
  • 原文地址:https://www.cnblogs.com/xiaxuexiaoab/p/15219567.html
Copyright © 2011-2022 走看看