"""Run a crowd-density model over a video and save one visualisation per frame.

For every frame of ``dataRoot`` the script writes a PNG into ``exp_name``
showing the source frame next to the predicted density map, with the predicted
count (sum of the density map) drawn on the map and encoded in the file name.
"""
import os
import time

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.animation import FFMpegWriter
from PIL import Image
from torchvision.transforms import functional

from models.models import Model

exp_name = './xxxx_results'
dataRoot = 'xxxx.mp4'
model_path = './checkpoint_best.pth'

# Input sides are snapped to a multiple of this value before inference.
# NOTE(review): presumably the network's total downsampling factor -- confirm.
_SIZE_MULTIPLE = 16
# Per-channel normalisation constants passed to ``functional.normalize``.
# NOTE(review): these look like the usual ImageNet statistics -- confirm.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]


def _snap_to_multiple(length):
    """Round ``length`` to the nearest multiple of ``_SIZE_MULTIPLE`` (never 0)."""
    snapped = round(length / _SIZE_MULTIPLE) * _SIZE_MULTIPLE
    # Inputs shorter than half a multiple would otherwise round down to 0 and
    # make the resize below fail.
    return max(_SIZE_MULTIPLE, snapped)


def pre_image(image):
    """Convert an OpenCV BGR frame into a display image and a model input.

    Args:
        image: Frame as returned by ``cv2.VideoCapture.read`` (BGR array).

    Returns:
        A tuple ``(input_image, batch)`` where ``input_image`` is an RGB
        ``PIL.Image`` copy of the frame at its original size, and ``batch`` is
        the resized, normalised tensor with a leading batch dimension of 1.
    """
    image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    input_image = image.copy()

    # PIL reports size as (width, height).
    width, height = image.size
    width = _snap_to_multiple(width)
    height = _snap_to_multiple(height)

    image = image.resize((width, height), Image.BILINEAR)
    image = functional.to_tensor(image)
    image = functional.normalize(image, _NORM_MEAN, _NORM_STD)
    return input_image, torch.unsqueeze(image, 0)


def main():
    """Load the checkpoint, run it on every frame and save the figures.

    Raises:
        IOError: If the input video ``dataRoot`` cannot be opened.
    """
    # Fall back to CPU so the script also runs on machines without CUDA.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # load model
    model = Model()
    # map_location lets a checkpoint saved on GPU be loaded on the chosen device.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model'])
    model.to(device)
    model.eval()

    # savefig does not create missing directories.
    os.makedirs(exp_name, exist_ok=True)

    # input video
    video = cv2.VideoCapture(dataRoot)
    if not video.isOpened():
        raise IOError('cannot open video: %s' % dataRoot)

    figure = plt.figure()
    try:
        fps = video.get(cv2.CAP_PROP_FPS)
        print(fps)
        frameCount = video.get(cv2.CAP_PROP_FRAME_COUNT)
        print(frameCount)

        index = 1
        success, frame = video.read()
        while success:
            src_image, batch = pre_image(frame)
            images = batch.to(device)

            # predict; no_grad avoids building an autograd graph per frame.
            with torch.no_grad():
                dense_map, _ = model(images)

            # Keep the first channel of the first (only) batch element.
            dense_map = dense_map.cpu().numpy()[0, 0, :, :]
            dense_pred_count = np.sum(dense_map)
            # Scale for display; the epsilon guards against an all-zero map.
            dense_map = dense_map / np.max(dense_map + 1e-20)

            plt.subplot(121)
            plt.imshow(src_image)
            plt.axis('off')

            plt.subplot(122)
            plt.imshow(dense_map)
            plt.text(25, 25, 'pred crowd count:%.4f ' % dense_pred_count,
                     fontdict={'size': 10, 'color': 'red'})
            plt.axis('off')
            plt.tight_layout(pad=0.3, w_pad=0, h_pad=1)

            out_name = '%05d_%d.png' % (index, int(dense_pred_count))
            plt.savefig(exp_name + '/' + out_name,
                        bbox_inches='tight', pad_inches=0, dpi=150)
            # Reuse the same figure for the next frame.
            plt.clf()

            success, frame = video.read()
            index += 1
    finally:
        # Release the capture and the figure even if a frame fails.
        video.release()
        plt.close(figure)


if __name__ == '__main__':
    main()