zoukankan      html  css  js  c++  java
  • PyTorch——模型推断——单张推断OpenCV(一)

     1 #coding= utf-8
     2 import os
     3 import torch
     4 from data_pipe import get_data
     5 from model import SimpleNet
     6 import numpy as np
     7 import cv2
     8 from PIL import Image
     9 
    10 
    11 class Infer(object):
    12 
    13     def __init__(self):
    14         self.model = SimpleNet()
    15         self.model.load_state_dict(torch.load("./models/model_10.pth"))
    16         self.model.eval()
    17 
    18     def _infer(self, img_tensor):
    19         with torch.no_grad():
    20             result = self.model(img_tensor)
    21         if result > 0.5:
    22             result = 1
    23         else:
    24             result = 0
    25         return result
    26 
    27     def predict(self, path):
    28         img_path_list = [os.path.join(path ,x) for x in os.listdir(path)]
    29         for img_path in img_path_list:
    30             print(img_path)
    31             img = cv2.imread(img_path)
    32             img = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)
    33             img_tensor = torch.from_numpy(np.asarray(img)).permute(2,0,1).float()/255.0
    34             img_tensor = img_tensor.reshape((1, 3, 32, 32))
    35             result = self._infer(img_tensor)
    36             print(result)
    37 
    38 
    39 if __name__ == "__main__":
    40     path = "./test_images"
    41     Infer().predict(path)
  • 相关阅读:
    Python 编码格式的使用
    解决cmd 运行python socket怎么终止运行
    解决win10子系统Ubuntu新装的mysql 不能root登陆方法
    浏览器中网址访问过程详解
    POJ 2502 最短路
    HDU 2859
    POJ 3186
    POJ 1661 暴力dp
    POJ 1015 陪审团问题
    CodeForces 1058E
  • 原文地址:https://www.cnblogs.com/timelesszxl/p/14595903.html
Copyright © 2011-2022 走看看