zoukankan      html  css  js  c++  java
  • PyTorch——模型推断——单张推断OpenCV(一)

     1 #coding= utf-8
     2 import os
     3 import torch
     4 from data_pipe import get_data
     5 from model import SimpleNet
     6 import numpy as np
     7 import cv2
     8 from PIL import Image
     9 
    10 
    11 class Infer(object):
    12 
    13     def __init__(self):
    14         self.model = SimpleNet()
    15         self.model.load_state_dict(torch.load("./models/model_10.pth"))
    16         self.model.eval()
    17 
    18     def _infer(self, img_tensor):
    19         with torch.no_grad():
    20             result = self.model(img_tensor)
    21         if result > 0.5:
    22             result = 1
    23         else:
    24             result = 0
    25         return result
    26 
    27     def predict(self, path):
    28         img_path_list = [os.path.join(path ,x) for x in os.listdir(path)]
    29         for img_path in img_path_list:
    30             print(img_path)
    31             img = cv2.imread(img_path)
    32             img = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)
    33             img_tensor = torch.from_numpy(np.asarray(img)).permute(2,0,1).float()/255.0
    34             img_tensor = img_tensor.reshape((1, 3, 32, 32))
    35             result = self._infer(img_tensor)
    36             print(result)
    37 
    38 
    39 if __name__ == "__main__":
    40     path = "./test_images"
    41     Infer().predict(path)
  • 相关阅读:
    Mina Core 10-执行器过滤器
    Mina Core 09-编解码过滤器
    Mina Core 08-IoBuffer
    Mina Basics 07-处理程序Handler
    Mina Basics 06-传输
    Mina Basics 05-过滤器
    Mina Basics 04- 会话
    Mina Basics 03-IoService
    Mina Basics 02-基础
    Mina Basics 01- 入门
  • 原文地址:https://www.cnblogs.com/timelesszxl/p/14595903.html
Copyright © 2011-2022 走看看