zoukankan      html  css  js  c++  java
  • matplotlib中plt用法实例

    import torch
    from models.models import Model
    import cv2
    from PIL import Image
    import numpy as np
    
    from matplotlib.animation import FFMpegWriter
    import time
    import matplotlib.pyplot as plt
    
    
    from torchvision.transforms import functional
    
    
    # Run configuration -- edit these paths before running the script.
    exp_name, dataRoot, model_path = (
        './xxxx_results',         # directory that receives one PNG per processed frame
        'xxxx.mp4',               # input video, opened with cv2.VideoCapture
        './checkpoint_best.pth',  # checkpoint whose 'model' entry is loaded into Model
    )
    
    
    def _snap_to_multiple(value, multiple=16):
        """Round ``value`` to the nearest multiple of ``multiple``, never below ``multiple``.

        Uses Python's built-in ``round`` (ties go to the even quotient), exactly
        like the original inline expression. The lower bound keeps a side
        shorter than ``multiple / 2`` from rounding down to 0, which would give
        a zero-sized image that can be neither resized to nor fed to the model.
        """
        return max(multiple, round(value / multiple) * multiple)


    def pre_image(image, multiple=16):
        """Turn an OpenCV BGR frame into (display image, normalized model input).

        Args:
            image: a frame as returned by ``cv2.VideoCapture.read()``; it is
                converted from BGR to RGB before any other processing.
            multiple: both sides of the model input are rounded to the nearest
                multiple of this value (default 16, the previous hard-coded
                value). NOTE(review): presumably the network's total
                downsampling factor -- confirm against Model.

        Returns:
            A tuple ``(input_image, tensor)``: ``input_image`` is the RGB PIL
            image at its original size (used for visualization), ``tensor`` is
            a ``1 x C x H x W`` float tensor normalized with the ImageNet
            mean/std.
        """
        image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        # Untouched copy for display; `image` itself is resized below.
        input_image = image.copy()

        width, height = image.size  # PIL reports (width, height)
        height = _snap_to_multiple(height, multiple)
        width = _snap_to_multiple(width, multiple)
        image = image.resize((width, height), Image.BILINEAR)

        tensor = functional.to_tensor(image)
        tensor = functional.normalize(tensor, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        # Add the leading batch dimension expected by the model.
        return input_image, torch.unsqueeze(tensor, 0)
    
    
    if __name__ == '__main__':
        import os

        device = torch.device('cuda:0')

        # savefig() below does not create missing directories.
        os.makedirs(exp_name, exist_ok=True)

        # Load the model. map_location places the stored tensors on `device`
        # even if the checkpoint was written from another GPU; `.to(device)`
        # (rather than `.cuda()`, which uses the *current* CUDA device) keeps
        # the model on the same device the input frames are moved to.
        model = Model()
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint['model'])
        model.to(device)
        model.eval()

        # Open the input video and report its basic properties.
        video = cv2.VideoCapture(dataRoot)
        if not video.isOpened():
            # Without this check a bad path silently produces no output at all.
            raise OSError('cannot open video: %s' % dataRoot)
        fps = video.get(cv2.CAP_PROP_FPS)
        print(fps)
        frameCount = video.get(cv2.CAP_PROP_FRAME_COUNT)
        print(frameCount)

        figure = plt.figure()
        try:
            success, frame = video.read()
            index = 1
            while success:
                src_image, batch = pre_image(frame)
                images = batch.to(device)

                # Inference only: without no_grad an autograd graph would be
                # built for every frame, wasting memory and time.
                with torch.no_grad():
                    dense_map, atten_map = model(images)

                # Take element [0, 0] of the output -> 2-D map.
                # NOTE(review): assumes an (N, C, H, W) output -- confirm.
                dense_map = dense_map.detach().cpu().numpy()[0, 0, :, :]

                # The count shown as 'pred crowd count' is the sum of the map;
                # the map is then scaled by its maximum purely for display
                # (the 1e-20 avoids dividing by zero on an all-zero map).
                dense_pred_count = np.sum(dense_map)
                dense_map = dense_map / np.max(dense_map + 1e-20)

                # Left panel: original frame.
                plt.subplot(121)
                plt.imshow(src_image)
                plt.axis('off')

                # Right panel: normalized density map with the predicted count.
                plt.subplot(122)
                plt.imshow(dense_map)
                plt.text(25, 25, 'pred crowd count:%.4f ' % dense_pred_count, fontdict={'size': 10, 'color': 'red'})
                plt.axis('off')

                plt.tight_layout(pad=0.3, w_pad=0, h_pad=1)

                # File name: zero-padded frame index + integer part of the count.
                out_path = os.path.join(exp_name, '%05d_%d.png' % (index, int(dense_pred_count)))
                plt.savefig(out_path, bbox_inches='tight', pad_inches=0, dpi=150)
                plt.clf()

                success, frame = video.read()
                index += 1
        finally:
            # Release the capture and the figure even if processing fails mid-video.
            video.release()
            plt.close(figure)
    
  • 相关阅读:
    js获取base64格式图片预览上传并用php保存到本地服务器指定文件夹
    matplotlib等值线显示
    Matplotlib调用imshow()函数绘制热图
    tensorflow 卷积神经网络预测手写数字
    tensorflow 参数初始化
    matplotlib 读取图形数据
    tensorflow载入数据的三种方式
    tf.get_variable函数的使用
    TF-卷积函数 tf.nn.conv2d 介绍
    Git 常用命令
  • 原文地址:https://www.cnblogs.com/wangyarui/p/11201110.html
Copyright © 2011-2022 走看看