zoukankan      html  css  js  c++  java
  • matplotlib中plt用法实例

    import torch
    from models.models import Model
    import cv2
    from PIL import Image
    import numpy as np
    
    from matplotlib.animation import FFMpegWriter
    import time
    import matplotlib.pyplot as plt
    
    
    from torchvision.transforms import functional
    
    
    exp_name = './xxxx_results'
    dataRoot = 'xxxx.mp4'
    model_path = './checkpoint_best.pth'
    
    
    def pre_image(image):
        image = Image.fromarray(cv2.cvtColor(image,cv2.COLOR_BGR2RGB))
        input_image = image.copy()
        # image.show()
        height, width = image.size[1], image.size[0]
        height = round(height / 16) * 16
        width = round(width / 16) * 16
        image = image.resize((width, height), Image.BILINEAR)
    
        image = functional.to_tensor(image)
        image = functional.normalize(image, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        return input_image,torch.unsqueeze(image,0)
    
    
    if __name__ == '__main__':
    
        device = torch.device('cuda:0')
    
        # load model
        model=Model()
        checkpoint = torch.load(model_path)
        model.load_state_dict(checkpoint['model'])
    
        model.cuda()
        model.eval()
    
        # input video
        video = cv2.VideoCapture(dataRoot)
        fps = video.get(cv2.CAP_PROP_FPS)
        print(fps)
        frameCount = video.get(cv2.CAP_PROP_FRAME_COUNT)
        print(frameCount)
        size = (int(video.get(cv2.CAP_PROP_FRAME_WIDTH)), int(video.get(cv2.CAP_PROP_FRAME_HEIGHT)))
    
    
        # metadata = dict(title='Video Test', artist='Matplotlib', comment='Movie support!')
        # writer = FFMpegWriter(fps=25, metadata=metadata)
    
            # videoWriter = cv2.VideoWriter('trans.mp4', cv2.VideoWriter_fourcc(*'MP4V'), fps, size)
        success, frame = video.read()
        index = 1
    
        figure = plt.figure()
        while success:
            # time1=time.time()
            src_image,frame = pre_image(frame)
            images = frame.to(device)
    
            # time1 = time.time()
    
    
            # ground truth
            # gt_path = dataRoot + '/den/' + filename_no_ext + '.csv'
    
            # predict
            dense_map,atten_map = model(images)
            # test = time.time() - time1
    
            dense_map = dense_map.cpu().data.numpy()[0,0,:,:]
            # test=time.time()-time1
    
            dense_pred_count = np.sum(dense_map)
            dense_map = dense_map/np.max(dense_map+1e-20)
    
            # cv2.imshow("image", dense_map)
            # cv2.waitKey(0)
    
    
            plt.subplot(121)
            plt.imshow(src_image)
            # plt.title('original image')
            plt.axis('off')
    
            plt.subplot(122)
            plt.imshow(dense_map)
            # plt.title('dense map')
            plt.text(25, 25, 'pred crowd count:%.4f ' % dense_pred_count, fontdict={'size': 10, 'color': 'red'})
            plt.axis('off')
    
            plt.tight_layout(pad=0.3, w_pad=0, h_pad=1)
    
            # anni=animation.FuncAnimation(fig, animate, init_func=init,frames=200, interval=20, blit=True)
            # anim.save('sin.gif', fps=75, writer='imagemagick')
            plt.savefig(exp_name + '/'+ str('%05d' % index) + '_' + str(int(dense_pred_count)) + '.png', bbox_inches='tight', pad_inches=0, dpi=150)
    
            # plt.show()
            plt.clf()
    
            success, frame = video.read()
            index += 1
    
        video.release()
    
  • 相关阅读:
    磁盘管理
    TCP/IP四层模型
    OSI七层模型详解
    kvm虚拟机
    mount 文件挂载
    ORA-01017: 用户名/口令无效; 登录被拒绝
    mybatis配置文件形式
    Spring+mybatis整合
    xmlBean学习二
    xmlBean学习一
  • 原文地址:https://www.cnblogs.com/wangyarui/p/11201110.html
Copyright © 2011-2022 走看看