"""Predict multi-digit labels for a folder of test images and write submit.csv.

A ResNet18 backbone feeds five independent 11-way classifiers, one per digit
position. Class 10 is the "no digit" padding class and is dropped when the
predicted digits are joined into the final string.
"""
import glob
import json
import os
import shutil
import sys

# Restrict CUDA to device 0; this has to be in the environment before CUDA is
# initialised, so it is set ahead of the torch import.
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

import cv2
from PIL import Image
import numpy as np
import pandas as pd
from tqdm import tqdm, tqdm_notebook
import torch

torch.manual_seed(0)
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True
import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data.dataset import Dataset

MAX_DIGITS = 5     # number of digit positions the model predicts
PAD_CLASS = 10     # class index meaning "no digit at this position"
NUM_CLASSES = 11   # digits 0-9 plus the padding class


# Dataset definition
class SVHNDataset(Dataset):
    """Image/label dataset yielding ``(image, label)`` pairs.

    Args:
        img_path: sequence of image file paths.
        img_label: sequence of digit sequences, parallel to ``img_path``.
        transform: optional callable applied to the loaded RGB PIL image.
    """

    def __init__(self, img_path, img_label, transform=None):
        self.img_path = img_path
        self.img_label = img_label
        self.transform = transform

    def __getitem__(self, index):
        """Return the transformed image and its label padded to MAX_DIGITS."""
        img = Image.open(self.img_path[index]).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        # np.int was removed in NumPy 1.24; build the label from plain ints.
        digits = [int(d) for d in self.img_label[index]][:MAX_DIGITS]
        # Right-pad short labels with the padding class (longer ones were
        # truncated above).
        digits += [PAD_CLASS] * (MAX_DIGITS - len(digits))
        return img, torch.tensor(digits, dtype=torch.long)

    def __len__(self):
        return len(self.img_path)


# ResNet18 is used as the feature extractor
class SVHN_Model1(nn.Module):
    """ResNet18 backbone followed by five parallel 11-way linear heads."""

    def __init__(self):
        super(SVHN_Model1, self).__init__()
        model_conv = models.resnet18(pretrained=True)
        model_conv.avgpool = nn.AdaptiveAvgPool2d(1)
        # Drop the final fully-connected layer; keep everything up to pooling.
        model_conv = nn.Sequential(*list(model_conv.children())[:-1])
        self.cnn = model_conv

        # Attribute names are part of the checkpoint's state-dict keys.
        self.fc1 = nn.Linear(512, NUM_CLASSES)
        self.fc2 = nn.Linear(512, NUM_CLASSES)
        self.fc3 = nn.Linear(512, NUM_CLASSES)
        self.fc4 = nn.Linear(512, NUM_CLASSES)
        self.fc5 = nn.Linear(512, NUM_CLASSES)

    def forward(self, img):
        """Return a 5-tuple of per-position class scores for a batch."""
        feat = self.cnn(img)
        feat = feat.view(feat.shape[0], -1)
        c1 = self.fc1(feat)
        c2 = self.fc2(feat)
        c3 = self.fc3(feat)
        c4 = self.fc4(feat)
        c5 = self.fc5(feat)
        return c1, c2, c3, c4, c5


def predict(test_loader_, model_, tta=10):
    """Sum the model's raw scores over ``tta`` passes of the loader.

    Args:
        test_loader_: iterable yielding ``(images, target)`` batches; the
            target is ignored.
        model_: module returning a tuple of per-head score tensors.
        tta: number of test-time-augmentation passes over the loader.

    Returns:
        A NumPy array with one row per sample, holding the heads' scores
        concatenated along axis 1 and summed over all passes. ``None`` when
        ``tta`` is less than 1.
    """
    model_.eval()

    # Run on whatever device the model already lives on, so inputs and
    # weights can never end up on different devices.
    first_param = next(model_.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    test_pred_tta = None

    # One iteration per TTA pass
    for _ in range(tta):
        test_pred = []

        with torch.no_grad():
            for images, _target in test_loader_:
                # Use the model that was passed in, not a module-level global.
                heads = model_(images.to(device))
                test_pred.append(torch.cat(heads, dim=1).cpu().numpy())

        test_pred = np.vstack(test_pred)
        if test_pred_tta is None:
            test_pred_tta = test_pred
        else:
            test_pred_tta += test_pred

    return test_pred_tta


def _decode_predictions(scores):
    """Convert concatenated head scores into one digit string per sample."""
    # Column block k*NUM_CLASSES:(k+1)*NUM_CLASSES belongs to digit position k.
    digits = scores.reshape(len(scores), MAX_DIGITS, NUM_CLASSES).argmax(axis=2)
    return [''.join(map(str, row[row != PAD_CLASS])) for row in digits]


def main():
    """Load the test images and checkpoint, predict, and write submit.csv."""
    # ---------------------------- load data and model ----------------------------
    test_path = sorted(glob.glob('mchar_test_a/mchar_test_a/*.png'))
    if not test_path:
        raise FileNotFoundError(
            "no test images matched 'mchar_test_a/mchar_test_a/*.png'")
    # Placeholder labels: the dataset needs one per image, predict ignores them.
    test_label = [[1]] * len(test_path)
    print(len(test_path), len(test_label))

    test_loader = torch.utils.data.DataLoader(
        SVHNDataset(test_path, test_label,
                    transforms.Compose([
                        transforms.Resize((64, 128)),
                        # Random augmentations are kept at test time so that
                        # repeated passes (TTA) see different views.
                        transforms.RandomCrop((60, 120)),
                        transforms.ColorJitter(0.3, 0.3, 0.2),
                        transforms.RandomRotation(10),
                        transforms.ToTensor(),
                        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                    ])),
        batch_size=40,
        shuffle=False,
        num_workers=10,
    )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = SVHN_Model1()

    # Load the trained weights, remapping them if they were saved on a
    # device that is not available here.
    model.load_state_dict(torch.load('model.pt', map_location=device))

    # The model must be on the same device as the inputs; otherwise the forward
    # pass fails with: RuntimeError: Input type (torch.cuda.FloatTensor) and
    # weight type (torch.FloatTensor) should be the same
    model = model.to(device)

    # Predict
    test_predict_label = predict(test_loader, model, 1)

    # Turn scores into label strings
    test_label_pred = _decode_predictions(test_predict_label)

    # Write the submission file
    df_submit = pd.read_csv('mchar_sample_submit_A.csv')
    df_submit['file_code'] = test_label_pred
    df_submit.to_csv('submit.csv', index=False)

    print('---')
    print()


if __name__ == '__main__':
    main()
下面是最终生成的 submit.csv 文件中的部分内容：
file_name    file_code
000000.png   5
000001.png   290
000002.png   155
000003.png   97
000004.png   63
000005.png   399
000006.png   226
000007.png   1471
000008.png   4
...