zoukankan      html  css  js  c++  java
  • matplotlib中plt用法实例

    import torch
    from models.models import Model
    import cv2
    from PIL import Image
    import numpy as np
    
    from matplotlib.animation import FFMpegWriter
    import time
    import matplotlib.pyplot as plt
    
    
    from torchvision.transforms import functional
    
    
    # Directory that the per-frame result figures are written into.
    exp_name = './xxxx_results'
    # Path of the input video that is read frame by frame.
    dataRoot = 'xxxx.mp4'
    # Checkpoint file; the script reads its 'model' entry as the state dict.
    model_path = './checkpoint_best.pth'
    
    
    def pre_image(image):
        """Turn an OpenCV BGR frame into a display image and a model input batch.

        Args:
            image: Frame array in BGR channel order (it is converted with
                ``cv2.COLOR_BGR2RGB``), e.g. as read from ``cv2.VideoCapture``.

        Returns:
            A tuple ``(input_image, batch)`` where ``input_image`` is a PIL RGB
            copy of the frame at its original resolution and ``batch`` is the
            normalized tensor of the resized frame with a leading batch
            dimension of size 1.
        """
        rgb = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        input_image = rgb.copy()

        # PIL reports size as (width, height).
        width, height = rgb.size

        # Snap both sides to the nearest multiple of 16 (presumably required by
        # the model's downsampling stages -- confirm against the Model
        # definition). Clamp to 16 so that a side shorter than 8 px does not
        # round down to 0, which would make resize() fail.
        height = max(16, round(height / 16) * 16)
        width = max(16, round(width / 16) * 16)
        resized = rgb.resize((width, height), Image.BILINEAR)

        tensor = functional.to_tensor(resized)
        # Per-channel mean/std normalization (these values match the commonly
        # used ImageNet statistics).
        tensor = functional.normalize(tensor, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        return input_image, torch.unsqueeze(tensor, 0)
    
    
    if __name__ == '__main__':
        # Run the model on every frame of the input video and save, per frame,
        # a side-by-side figure of the source image and the predicted density
        # map into the `exp_name` directory.
        import os  # local import: only this script entry point needs it

        # Prefer the first GPU, but fall back to CPU so the script still runs
        # on machines without CUDA.
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        # Load the model; map_location lets a GPU-saved checkpoint be loaded
        # on whichever device was selected above.
        model = Model()
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint['model'])
        model.to(device)
        model.eval()

        # savefig() does not create missing directories.
        os.makedirs(exp_name, exist_ok=True)

        # Open the input video and fail loudly instead of silently producing
        # no output when the file cannot be read.
        video = cv2.VideoCapture(dataRoot)
        if not video.isOpened():
            video.release()
            raise IOError('cannot open video: %s' % dataRoot)

        figure = plt.figure()
        try:
            fps = video.get(cv2.CAP_PROP_FPS)
            print(fps)
            frameCount = video.get(cv2.CAP_PROP_FRAME_COUNT)
            print(frameCount)

            index = 1
            success, frame = video.read()

            # Inference only: disable autograd so no graph is built per frame.
            with torch.no_grad():
                while success:
                    src_image, batch = pre_image(frame)
                    images = batch.to(device)

                    # The model returns two outputs; only the first (the
                    # density map) is used here.
                    dense_map, _ = model(images)

                    # Take the first sample and first channel as a 2-D array.
                    dense_map = dense_map.detach().cpu().numpy()[0, 0, :, :]

                    # The predicted count is the sum over the map; it must be
                    # taken before the map is rescaled for display.
                    dense_pred_count = np.sum(dense_map)
                    # Scale by the maximum for display; the epsilon avoids a
                    # division by zero on an all-zero map.
                    dense_map = dense_map / np.max(dense_map + 1e-20)

                    # Left panel: source frame.
                    plt.subplot(121)
                    plt.imshow(src_image)
                    plt.axis('off')

                    # Right panel: density map annotated with the count.
                    plt.subplot(122)
                    plt.imshow(dense_map)
                    plt.text(25, 25, 'pred crowd count:%.4f ' % dense_pred_count, fontdict={'size': 10, 'color': 'red'})
                    plt.axis('off')

                    plt.tight_layout(pad=0.3, w_pad=0, h_pad=1)

                    # File name: <5-digit frame index>_<integer count>.png
                    out_path = '%s/%05d_%d.png' % (exp_name, index, int(dense_pred_count))
                    plt.savefig(out_path, bbox_inches='tight', pad_inches=0, dpi=150)

                    # Reuse the same figure for the next frame.
                    plt.clf()

                    success, frame = video.read()
                    index += 1
        finally:
            # Release the capture and the figure even if a frame fails.
            video.release()
            plt.close(figure)
    
  • 相关阅读:
    css3圆角细节
    css3伪元素
    使用vscode在谷歌上运行代码
    SpringCloud-技术专区-Gateway优雅的处理Filter抛出的异常
    SpringCloud-技术专区-Gateway全局通用异常处理
    Mybatis-技术专区-插件开发指南
    消息中间件-技术专区-RabbitMQ基本介绍
    SpringBoot-技术专区-自定义TaskExecutor线程池
    MySQL-技术专区-Binlog和Redolog的介绍
    SpringBoot-技术专区-Redis同数据源动态切换db
  • 原文地址:https://www.cnblogs.com/wangyarui/p/11201110.html
Copyright © 2011-2022 走看看