import os
import sys
import torch
from torchvision import transforms
from data_pipe import get_data
from vgg import VGG_13
from resnet18 import ResNet18
import numpy as np
import cv2
from PIL import Image


class Infer(object):
    """Run a ResNet18 checkpoint over every image in a directory.

    The constructor builds the network, loads the weights from
    ``./models/model_65.pth`` and switches the model to evaluation mode.
    ``predict`` then classifies each readable image in a folder and prints
    the predicted label.
    """

    def __init__(self):
        self.model = ResNet18()
        # map_location="cpu": the model is only ever used on the CPU here, so
        # load the checkpoint onto the CPU as well. Without it, a checkpoint
        # saved from a GPU cannot be deserialized on a CPU-only machine.
        state_dict = torch.load("./models/model_65.pth", map_location="cpu")
        self.model.load_state_dict(state_dict)
        self.model.eval()

        # Label string -> model output index. Keys are in lexicographic
        # string order (' 1', ' 10', ' 11', ...), presumably mirroring the
        # class-to-index mapping used at training time -- confirm against the
        # training pipeline before editing.
        self.cls = {
            ' 0': 0, ' 1': 1, ' 10': 2, ' 11': 3, ' 12': 4,
            ' 13': 5, ' 14': 6, ' 15': 7, ' 16': 8, ' 17': 9,
            ' 18': 10, ' 19': 11, ' 2': 12, ' 20': 13, ' 21': 14,
            ' 22': 15, ' 23': 16, ' 24': 17, ' 25': 18, ' 26': 19,
            ' 27': 20, ' 28': 21, ' 29': 22, ' 3': 23, ' 30': 24,
            ' 31': 25, ' 32': 26, ' 33': 27, ' 34': 28, ' 35': 29,
            ' 36': 30, ' 37': 31, ' 38': 32, ' 39': 33, ' 4': 34,
            ' 5': 35, ' 6': 36, ' 7': 37, ' 8': 38, ' 9': 39,
        }
        # Inverse mapping: model output index -> label string.
        self.new_cls = dict(zip(self.cls.values(), self.cls.keys()))

    def _infer(self, img_tensor):
        """Return the raw model output for ``img_tensor`` without tracking gradients."""
        with torch.no_grad():
            result = self.model(img_tensor)
        return result

    def predict(self, path):
        """Classify every readable image directly inside ``path``.

        Each prediction is printed on its own line (label with surrounding
        whitespace stripped). Directory entries are processed in sorted name
        order so the output is deterministic. Entries OpenCV cannot decode
        (sub-directories, non-image files, corrupt images) are reported on
        stderr and skipped instead of aborting the whole run.

        Args:
            path: Directory containing the images to classify.

        Returns:
            The list of predicted label strings, in processing order.
        """
        # Resize to 224x224, convert to a tensor and normalize per channel.
        # NOTE(review): mean/std look like the usual ImageNet statistics --
        # confirm they match what the model was trained with.
        transform = transforms.Compose([
            transforms.Resize([224, 224]),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])])

        labels = []
        for name in sorted(os.listdir(path)):
            img_path = os.path.join(path, name)
            img = cv2.imread(img_path)
            if img is None:
                # cv2.imread signals failure by returning None, not raising.
                sys.stderr.write("skipping unreadable image: %s\n" % img_path)
                continue

            # OpenCV decodes to BGR; the PIL-based transform expects RGB.
            img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            # Add the leading batch dimension the model expects.
            img_tensor = transform(img).unsqueeze(0)
            result = self._infer(img_tensor)
            _, preds = torch.max(result, dim=1)
            label = self.new_cls[preds.item()].strip()
            print(label)
            labels.append(label)
        return labels


if __name__ == "__main__":
    path = "./test_images"
    Infer().predict(path)