  • An example of plt usage in matplotlib
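
    The script below runs a crowd-counting model frame by frame over a video and uses matplotlib to draw each input frame next to its predicted density map, saving every figure as a PNG (the model, checkpoint, and video paths are placeholders).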

    import torch
    from models.models import Model
    import cv2
    from PIL import Image
    import numpy as np
    
    import os
    import time

    import matplotlib.pyplot as plt
    from matplotlib.animation import FFMpegWriter
    
    
    from torchvision.transforms import functional
    
    
    exp_name = './xxxx_results'
    dataRoot = 'xxxx.mp4'
    model_path = './checkpoint_best.pth'
    
    
    def pre_image(image):
        # OpenCV reads frames as BGR; convert to an RGB PIL image
        image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        input_image = image.copy()
        # image.show()
        height, width = image.size[1], image.size[0]
        # snap both dimensions to multiples of 16, presumably to match the network's downsampling factor
        height = round(height / 16) * 16
        width = round(width / 16) * 16
        image = image.resize((width, height), Image.BILINEAR)

        # to tensor, then normalize with the standard ImageNet mean and std
        image = functional.to_tensor(image)
        image = functional.normalize(image, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        # return the original PIL image plus a batched 4-D input tensor
        return input_image, torch.unsqueeze(image, 0)
    
    
    if __name__ == '__main__':

        # make sure the output directory exists before saving figures into it
        os.makedirs(exp_name, exist_ok=True)

        device = torch.device('cuda:0')
    
        # load model
        model = Model()
        checkpoint = torch.load(model_path)
        model.load_state_dict(checkpoint['model'])
    
        model.cuda()
        model.eval()
    
        # input video
        video = cv2.VideoCapture(dataRoot)
        fps = video.get(cv2.CAP_PROP_FPS)
        print(fps)
        frameCount = video.get(cv2.CAP_PROP_FRAME_COUNT)
        print(frameCount)
        size = (int(video.get(cv2.CAP_PROP_FRAME_WIDTH)), int(video.get(cv2.CAP_PROP_FRAME_HEIGHT)))
    
    
        # metadata = dict(title='Video Test', artist='Matplotlib', comment='Movie support!')
        # writer = FFMpegWriter(fps=25, metadata=metadata)

        # videoWriter = cv2.VideoWriter('trans.mp4', cv2.VideoWriter_fourcc(*'MP4V'), fps, size)
        success, frame = video.read()
        index = 1
    
        figure = plt.figure()
        while success:
            # time1=time.time()
            # keep the original frame for display and build the normalized input tensor
            src_image, frame = pre_image(frame)
            images = frame.to(device)
    
            # time1 = time.time()
    
    
            # ground truth
            # gt_path = dataRoot + '/den/' + filename_no_ext + '.csv'
    
            # predict; no_grad avoids building the autograd graph during inference
            with torch.no_grad():
                dense_map, atten_map = model(images)
            # test = time.time() - time1
    
            # drop the batch and channel dimensions
            dense_map = dense_map.cpu().numpy()[0, 0, :, :]
            # test=time.time()-time1
    
            # the predicted count is the integral of the density map
            dense_pred_count = np.sum(dense_map)
            # normalize to [0, 1] for display; the epsilon guards against an all-zero map
            dense_map = dense_map / (np.max(dense_map) + 1e-20)
    
            # cv2.imshow("image", dense_map)
            # cv2.waitKey(0)
    
    
            # left panel: the original frame
            plt.subplot(121)
            plt.imshow(src_image)
            # plt.title('original image')
            plt.axis('off')

            # right panel: the predicted density map with the count overlaid
            plt.subplot(122)
            plt.imshow(dense_map)
            # plt.title('dense map')
            plt.text(25, 25, 'pred crowd count: %.4f' % dense_pred_count, fontdict={'size': 10, 'color': 'red'})
            plt.axis('off')

            plt.tight_layout(pad=0.3, w_pad=0, h_pad=1)
    
            # anni=animation.FuncAnimation(fig, animate, init_func=init,frames=200, interval=20, blit=True)
            # anim.save('sin.gif', fps=75, writer='imagemagick')
            # file name: zero-padded frame index plus the rounded predicted count
            plt.savefig(exp_name + '/' + ('%05d' % index) + '_' + str(int(dense_pred_count)) + '.png',
                        bbox_inches='tight', pad_inches=0, dpi=150)

            # plt.show()
            plt.clf()  # clear the figure so the next frame starts fresh
    
            success, frame = video.read()
            index += 1
    
        video.release()
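
    The per-frame plotting is a small reusable plt pattern: draw into two subplots, overlay text, save the figure, then clear it for the next iteration. Below is a minimal, self-contained sketch of that same pattern, with random arrays standing in for a real frame and density map (the output filenames are illustrative):

        import numpy as np
        import matplotlib.pyplot as plt

        figure = plt.figure()
        for index in range(3):
            src_image = np.random.rand(240, 320, 3)  # stand-in for a video frame
            dense_map = np.random.rand(240, 320)     # stand-in for a predicted density map

            plt.subplot(121)                         # left panel: the frame
            plt.imshow(src_image)
            plt.axis('off')

            plt.subplot(122)                         # right panel: the density map
            plt.imshow(dense_map)
            plt.text(25, 25, 'pred crowd count: %.4f' % dense_map.sum(),
                     fontdict={'size': 10, 'color': 'red'})
            plt.axis('off')

            plt.tight_layout(pad=0.3, w_pad=0, h_pad=1)
            plt.savefig('frame_%05d.png' % index, bbox_inches='tight', pad_inches=0, dpi=150)
            plt.clf()                                # reuse the same figure for the next frame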
    
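    The commented-out FFMpegWriter lines hint at an alternative output path: instead of writing one PNG per frame, matplotlib's animation writer can stream grabbed figures straight into a video file (ffmpeg must be on the PATH). A hedged sketch of that approach, with random data and an illustrative output name:

        import numpy as np
        import matplotlib.pyplot as plt
        from matplotlib.animation import FFMpegWriter

        metadata = dict(title='Video Test', artist='Matplotlib', comment='Movie support!')
        writer = FFMpegWriter(fps=25, metadata=metadata)

        figure = plt.figure()
        with writer.saving(figure, 'trans.mp4', dpi=150):  # 'trans.mp4' is illustrative
            for _ in range(50):
                dense_map = np.random.rand(240, 320)       # stand-in for a model prediction
                plt.imshow(dense_map)
                plt.axis('off')
                writer.grab_frame()                        # append the current figure as one video frame
                plt.clf()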
  • Original post: https://www.cnblogs.com/wangyarui/p/11201110.html