zoukankan      html  css  js  c++  java
  • 深度网络学习-PyTorch_应用模型进行推断-单张图片的预测

    说明

    没有联网,先把模型下载下来
    先学习怎么推断,
      然后再看怎么进行Dataset Dataloader transform
     接着看怎么训练和评价
    

    软件和硬件

    cuda
      查看cuda 版本
      whereis nvcc
      /usr/local/cuda-10.0/bin/nvcc -V
      cat /usr/local/cuda/version.txt   
    libcudnn.so最终链接的文件名,文件名中包含版本号
    GPU查看
      lspci | grep -i nvidia 
      nvidia-sm
      watch -n 1 nvidia-sm
    

    示例代码

    import torch
    import torch.cuda
    import torch.nn
    import torchvision.models as models
    import torchvision.transforms as transforms
    import numpy as np
    import cv2
    
    def get_model():
        # 加载模型 model_ft = torchvision.models.vgg16(pretrained=False)
        model_ft = models.resnet101(pretrained=False)
        #model_path ="./models/vgg16-397923af.pth"
        model_path ="./models/resnet101-5d3b4d8f.pth"
        pre = torch.load(model_path)
        model_ft.load_state_dict(pre)
        model_ft.cuda()
        return model_ft
    
        # # 查看模型结构
        # print(model_ft)
        # # 查看网络参数
        # for name, parameters in model_ft.named_parameters():
        #     print(name, ':', parameters.size())
        # # 网络模型的卷积方式以及权重数值
        # print("#############-parameters")
        # for child in model_ft.children():
        #     print(child) 
        #     # for param in child.parameters():
        #     #     print(param)
    
    def deal_img(img_path):
        """Transforming images on GPU"""
        image = cv2.imread(img_path) 
        image_new =  cv2.resize(image, (224,224))
        my_transforms= transforms.Compose(
            [ 
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229,0.224,0.225]) 
            ]
            )
        my_tensor = my_transforms(image_new)
        my_tensor = my_tensor.resize_(1,3,224,224)
        my_tensor= my_tensor.cuda()
        return my_tensor
    
    def cls_inference(cls_model,imgpth):
        input_tensor = deal_img(imgpth)
        cls_model.eval()
        result = cls_model(input_tensor)
        result_npy = result.data.cpu().numpy()
        max_index = np.argmax(result_npy[0])
        return max_index
    
    def feature_extract(cls_model,imgpth):
        cls_model.fc = torch.nn.LeakyReLU(0.1)
        cls_model.eval()
        input_tensor = deal_img(imgpth)
        result = cls_model(input_tensor)
        result_npy = result.data.cpu().numpy()
        return result_npy[0]
    
    
    if __name__ == "__main__":
        image_path="./pytorch/data/train/cat/08.jpg"
        model = get_model()
        cls_label = cls_inference(model,image_path)
        print(cls_label)
        feature = feature_extract(model,image_path)
        print(feature)
    

    参考

    使用pytorch预训练模型分类与特征提取 https://blog.csdn.net/u010165147/article/details/72829969?spm=1001.2014.3001.5502
  • 相关阅读:
    JavaEE三层架构
    请求重定向
    响应的中文乱码问题
    Apache的ServerAlias的作用
    bootstrap 常用class
    linux 退出当前命令的编辑
    硬链接和软链接
    ALTER TABLE causes auto_increment resequencing, resulting in duplicate entry ’1′ for key ‘PRIMARY’
    ie浏览器许多图片放在一起会有间隙
    Could not initialize class utils.JdbcUtils
  • 原文地址:https://www.cnblogs.com/ytwang/p/15239236.html
Copyright © 2011-2022 走看看