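  • National Day Travel Must-Have: Magical Black Tech Swaps the Sky and Conjures New Worlds, Turning Your Vlog into a Sci-Fi Blockbuster in Seconds!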

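    Abstract: During the National Day holiday, tourist spots are so crowded that every photo is full of people and cars. What can you do about it? Try this piece of black tech and turn your travel vlog into a sci-fi blockbuster in seconds.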

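    This article is shared from the Huawei Cloud Community post “National Day Travel Must-Have: Magical Black Tech Swaps the Sky and Conjures New Worlds, Turning Your Vlog into a Sci-Fi Blockbuster in Seconds!”, by 技术火炬手.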

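    Whether you are shooting people, scenery, or anything else on a National Day trip, the sky is a key element. An ordinary landscape shot, for example, instantly gains atmosphere once it carries the warm tones of a sunset sky.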

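    Of course, a natural sky is not the coolest option. Today we introduce an AI method that works directly on raw video: it can replace the sky background in one click and can also build any “castle in the sky” you can imagine.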

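    For example, you can swap in the vast starry sky and starships of Star Trek, turning an ordinary handheld vlog into a sci-fi blockbuster without any visible seams.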

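    The method comes from SkyAR, an open-source project on GitHub. It automatically detects the sky, segments it out of the image, and replaces it with the target sky, producing the "sky swap" effect.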
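    At its core, the final "replace" step is alpha compositing: a soft sky mask (1 = sky, 0 = foreground) weights the new sky against the original frame. The sketch below shows only that blend in NumPy, assuming float RGB images in [0, 1] of the same size. `blend_sky` is a hypothetical helper, not part of SkyAR's API; SkyAR additionally refines the mask and tracks camera motion before blending.

```python
import numpy as np

def blend_sky(frame: np.ndarray, sky: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Alpha-composite a new sky into a frame (hypothetical helper, not SkyAR's API).

    frame, sky: float arrays of shape (H, W, 3) with values in [0, 1].
    mask: float array of shape (H, W) or (H, W, 1); 1.0 means sky, 0.0 means foreground.
    """
    if mask.ndim == 2:
        mask = mask[..., None]          # add a channel axis so it broadcasts over RGB
    mask = np.clip(mask, 0.0, 1.0)      # keep the blend weights in a valid range
    return mask * sky + (1.0 - mask) * frame

# Tiny example: the top row is sky, the bottom row is foreground
frame = np.zeros((2, 2, 3), dtype=np.float32)   # black original frame
sky = np.ones((2, 2, 3), dtype=np.float32)      # white replacement sky
mask = np.array([[1.0, 1.0], [0.0, 0.0]], dtype=np.float32)
out = blend_sky(frame, sky, mask)
# The top row of `out` now comes from the sky, the bottom row from the original frame
```

    A soft (fractional) mask is what lets edges such as tree branches or hair fade smoothly between the two images instead of showing a hard cut.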

    Below, we use SkyAR and ModelArts JupyterLab to "swap skies and build worlds" from scratch. With enough imagination, this AI technique supports endless creative uses.

    This case runs on both CPU and GPU. It takes about 9 minutes on a CPU and about 2 minutes on a GPU.

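    Objectives

    After completing this case, you will:

    understand the basic uses of image segmentation;

    understand the basic uses of motion estimation;

    understand the basic uses of image blending.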
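    Motion estimation is what keeps the new sky attached to the scene when the camera pans. SkyAR's own tracker is not reproduced here; as a swapped-in, self-contained illustration, the sketch below estimates a pure global translation between two grayscale frames with phase correlation in NumPy. `estimate_shift` is a hypothetical helper, and real footage would also need rotation handling and outlier rejection.

```python
import numpy as np

def estimate_shift(prev: np.ndarray, curr: np.ndarray) -> tuple:
    """Estimate the integer (dy, dx) such that curr ~= np.roll(prev, (dy, dx), axis=(0, 1)).

    Phase correlation: the normalized cross-power spectrum of two shifted images
    is a pure phase ramp, and its inverse FFT peaks at the shift.
    """
    F_prev = np.fft.fft2(prev)
    F_curr = np.fft.fft2(curr)
    cross = F_curr * np.conj(F_prev)
    cross /= np.abs(cross) + 1e-12          # keep only the phase information
    corr = np.abs(np.fft.ifft2(cross))
    dy, dx = np.unravel_index(np.argmax(corr), corr.shape)
    h, w = prev.shape
    # Peaks past the midpoint correspond to negative (wrapped-around) shifts
    if dy > h // 2:
        dy -= h
    if dx > w // 2:
        dx -= w
    return int(dy), int(dx)

rng = np.random.default_rng(0)
prev = rng.random((64, 64))
curr = np.roll(prev, shift=(3, -5), axis=(0, 1))   # simulate a camera move
print(estimate_shift(prev, curr))                  # (3, -5)
```

    The estimated offset would then be used to move the sky image by the same amount, so that the replacement sky drifts together with the real background.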

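    Notes

    1. If this is your first time using JupyterLab, see the “ModelArts JupyterLab User Guide” to learn how to use it;
    2. If you run into an error while using JupyterLab, refer to “ModelArts JupyterLab FAQ and Solutions” and try to resolve it.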

    Procedure

    1. Install and import dependencies

    import os
    import moxing as mox
    
    # Download and unpack the SkyAR code, test video, and checkpoints from OBS (first run only)
    file_name = 'SkyAR'
    if not os.path.exists(file_name):
        mox.file.copy('obs://modelarts-labs-bj4-v2/case_zoo/SkyAR/SkyAR.zip', 'SkyAR.zip')
        os.system('unzip SkyAR.zip')
        os.system('rm SkyAR.zip')
    # Place the pretrained ResNet-50 weights in PyTorch's checkpoint cache
    mox.file.copy_parallel('obs://modelarts-labs-bj4-v2/case_zoo/SkyAR/resnet50-19c8e357.pth', '/home/ma-user/.cache/torch/checkpoints/resnet50-19c8e357.pth')
    INFO:root:Using MoXing-v1.17.3-43fbf97f
    INFO:root:Using OBS-Python-SDK-3.20.7
    !pip uninstall opencv-python -y
    !pip uninstall opencv-contrib-python -y
    Found existing installation: opencv-python 4.1.2.30
    Uninstalling opencv-python-4.1.2.30:
      Successfully uninstalled opencv-python-4.1.2.30
    WARNING: Skipping opencv-contrib-python as it is not installed.
    !pip install opencv-contrib-python==4.5.3.56
    Looking in indexes: http://repo.myhuaweicloud.com/repository/pypi/simple
    Collecting opencv-contrib-python==4.5.3.56
      Downloading http://repo.myhuaweicloud.com/repository/pypi/packages/3f/ce/36772cc6d9061b423b080e86919fd62cdef0837263f29ba6ff92e07f72d7/opencv_contrib_python-4.5.3.56-cp37-cp37m-manylinux2014_x86_64.whl (56.1 MB)
         |████████████████████████████████| 56.1 MB 166 kB/s
    Requirement already satisfied: numpy>=1.14.5 in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from opencv-contrib-python==4.5.3.56) (1.20.3)
    Installing collected packages: opencv-contrib-python
    Successfully installed opencv-contrib-python-4.5.3.56
    WARNING: You are using pip version 20.3.3; however, version 21.1.3 is available.
    You should consider upgrading via the '/home/ma-user/anaconda3/envs/PyTorch-1.4/bin/python -m pip install --upgrade pip' command.
    %cd SkyAR/
    /home/ma-user/work/Untitled Folder/SkyAR
    import time
    import json
    import base64
    import numpy as np
    import matplotlib.pyplot as plt
    import cv2
    import argparse
    from networks import *
    from skyboxengine import *
    import utils
    import torch
    from IPython.display import clear_output, Image, display, HTML
    %matplotlib inline
    
    # Run on the GPU if one is available
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    INFO:matplotlib.font_manager:generated new fontManager

    2. Preview the original video

    video_name = "test_videos/sky.mp4"
    def arrayShow(img):
        # Downscale to 1/4 size and JPEG-encode the BGR frame for inline display
        img = cv2.resize(img, (0, 0), fx=0.25, fy=0.25, interpolation=cv2.INTER_NEAREST)
        _, ret = cv2.imencode('.jpg', img)
        return Image(data=ret.tobytes())
    
    # Open a video stream
    cap = cv2.VideoCapture(video_name)
    
    frame_id = 0
    while True:
        try:
            clear_output(wait=True)  # Clear the previous output
            ret, frame = cap.read()  # Read one frame
            if ret:
                frame_id += 1
                if frame_id > 200:  # Preview only the first 200 frames
                    break
                cv2.putText(frame, str(frame_id), (5, 15), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)  # Draw frame_id
                # cv2.imencode expects BGR, so the frame is passed on without color conversion
                img = arrayShow(frame)
                display(img)  # Show the frame
                time.sleep(0.05)  # Pause briefly before the next frame
            else:
                break
        except KeyboardInterrupt:
            break  # Stop on interrupt; the capture is released below
    cap.release()

    3. Preview the replacement sky image

    img = cv2.imread('skybox/sky.jpg')
    img2 = img[:, :, ::-1]  # OpenCV loads images as BGR; reverse the channels to RGB for matplotlib
    plt.imshow(img2)
    <matplotlib.image.AxesImage at 0x7fbea986c590>

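    4. Customize the run parameters

    Modify the parameters below as needed. They only affect inference and rendering; no model training happens in this case.

    skybox_center_crop: center offset of the skybox

    auto_light_matching: automatic brightness matching

    relighting_factor: relighting (fill-light) strength

    recoloring_factor: recoloring strength

    halo_effect: halo effect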
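    To build some intuition for the last three knobs, here is a hypothetical NumPy sketch of what a recoloring factor, a relighting factor, and a halo could do to a frame. It is an assumption for illustration, not SkyAR's implementation: `recolor` shifts the foreground's mean color toward the sky's mean color by `recoloring_factor`, `relight` scales brightness by `relighting_factor`, and `add_halo` uses screen blending, 1 - (1 - a)(1 - b), a standard way to add a glow. All images are float RGB in [0, 1]; the glow layer here is simply a scaled mask, whereas a real halo would usually be a blurred copy of the sky.

```python
import numpy as np

def recolor(img, sky, mask, recoloring_factor):
    """Shift the foreground's mean color toward the sky's mean color (hypothetical)."""
    fg = 1.0 - mask                                          # foreground weights, shape (H, W, 1)
    fg_mean = (img * fg).sum(axis=(0, 1)) / (fg.sum() + 1e-9)
    sky_mean = sky.mean(axis=(0, 1))
    return img + recoloring_factor * (sky_mean - fg_mean)

def relight(img, relighting_factor):
    """Scale overall brightness; factors below 1.0 darken the scene (hypothetical)."""
    return relighting_factor * img

def add_halo(img, glow):
    """Screen blend: brightens wherever the glow layer is bright and never darkens."""
    return 1.0 - (1.0 - img) * (1.0 - glow)

img = np.full((4, 4, 3), 0.5)        # flat gray frame
sky = np.zeros((4, 4, 3))
sky[..., 2] = 1.0                    # pure blue sky
mask = np.zeros((4, 4, 1))
mask[:2] = 1.0                       # top half is sky

toned = relight(recolor(img, sky, mask, recoloring_factor=0.5), relighting_factor=0.8)
out = np.clip(add_halo(toned, 0.2 * mask), 0.0, 1.0)
# Bottom half: gray pulled toward blue, then dimmed; top half: additionally brightened by the glow
```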

    parameter = {
      "net_G": "coord_resnet50",
      "ckptdir": "./checkpoints_G_coord_resnet50",
    
      "input_mode": "video",
      "datadir": "./test_videos/sky.mp4",
      "skybox": "sky.jpg",
    
      "in_size_w": 384,
      "in_size_h": 384,
      "out_size_w": 845,
      "out_size_h": 480,
    
      "skybox_center_crop": 0.5,
      "auto_light_matching": False,
      "relighting_factor": 0.8,
      "recoloring_factor": 0.5,
      "halo_effect": True,
    
      "output_dir": "./jpg_output",
      "save_jpgs": False
    }
    
    str_json = json.dumps(parameter)
    class Struct:
        def __init__(self, **entries):
            self.__dict__.update(entries)
    def parse_config():
        data = json.loads(str_json)
        args = Struct(**data)
    
        return args
    args = parse_config()
    class SkyFilter():
    
        def __init__(self, args):
    
            self.ckptdir = args.ckptdir
            self.datadir = args.datadir
            self.input_mode = args.input_mode
    
            self.in_size_w, self.in_size_h = args.in_size_w, args.in_size_h
            self.out_size_w, self.out_size_h = args.out_size_w, args.out_size_h
    
            self.skyboxengine = SkyBox(args)
    
            self.net_G = define_G(input_nc=3, output_nc=1, ngf=64, netG=args.net_G).to(device)
            self.load_model()
    
            self.video_writer = cv2.VideoWriter('out.avi',
                                                cv2.VideoWriter_fourcc(*'MJPG'),
                                                20.0,
                                                (args.out_size_w, args.out_size_h))
            self.video_writer_cat = cv2.VideoWriter('compare.avi',
                                                    cv2.VideoWriter_fourcc(*'MJPG'),
                                                    20.0,
                                                    (2*args.out_size_w, args.out_size_h))
    
            if os.path.exists(args.output_dir) is False:
                os.mkdir(args.output_dir)
    
            self.output_img_list = []
    
            self.save_jpgs = args.save_jpgs
        def load_model(self):
            # Load the pretrained sky-matting model
            print('loading the best checkpoint...')
            checkpoint = torch.load(os.path.join(self.ckptdir, 'best_ckpt.pt'),
                                    map_location=device)
            self.net_G.load_state_dict(checkpoint['model_G_state_dict'])
            self.net_G.to(device)
            self.net_G.eval()
        def write_video(self, img_HD, syneth):
    
            frame = np.array(255.0 * syneth[:, :, ::-1], dtype=np.uint8)
            self.video_writer.write(frame)
    
            frame_cat = np.concatenate([img_HD, syneth], axis=1)
            frame_cat = np.array(255.0 * frame_cat[:, :, ::-1], dtype=np.uint8)
            self.video_writer_cat.write(frame_cat)
    
            # Append the side-by-side frame to the result buffer
            self.output_img_list.append(frame_cat)
        def synthesize(self, img_HD, img_HD_prev):
    
            h, w, c = img_HD.shape
    
            img = cv2.resize(img_HD, (self.in_size_w, self.in_size_h))
    
            img = np.array(img, dtype=np.float32)
            img = torch.tensor(img).permute([2, 0, 1]).unsqueeze(0)
    
            with torch.no_grad():
                G_pred = self.net_G(img.to(device))
                G_pred = torch.nn.functional.interpolate(G_pred,
                                                         (h, w),
                                                         mode='bicubic',
                                                         align_corners=False)
                G_pred = G_pred[0, :].permute([1, 2, 0])
                G_pred = torch.cat([G_pred, G_pred, G_pred], dim=-1)
                G_pred = np.array(G_pred.detach().cpu())
                G_pred = np.clip(G_pred, a_max=1.0, a_min=0.0)
    
            skymask = self.skyboxengine.skymask_refinement(G_pred, img_HD)
    
            syneth = self.skyboxengine.skyblend(img_HD, img_HD_prev, skymask)
    
            return syneth, G_pred, skymask
        def cvtcolor_and_resize(self, img_HD):
    
            img_HD = cv2.cvtColor(img_HD, cv2.COLOR_BGR2RGB)
            img_HD = np.array(img_HD / 255., dtype=np.float32)
            img_HD = cv2.resize(img_HD, (self.out_size_w, self.out_size_h))
    
            return img_HD
        def process_video(self):
            # Process the video frame by frame
            cap = cv2.VideoCapture(self.datadir)
            m_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            img_HD_prev = None
    
            for idx in range(m_frames):
                ret, frame = cap.read()
                if ret:
                    img_HD = self.cvtcolor_and_resize(frame)
    
                    if img_HD_prev is None:
                        img_HD_prev = img_HD
    
                    syneth, G_pred, skymask = self.synthesize(img_HD, img_HD_prev)
    
                    self.write_video(img_HD, syneth)
    
                    img_HD_prev = img_HD
    
                    if (idx + 1) % 50 == 0:
                        print(f'processing video, frame {idx + 1} / {m_frames} ... ')
    
                else:  # Reached the last frame
                    break
    
            # Release the capture and finalize both AVI files so they can be reopened in step 6
            cap.release()
            self.video_writer.release()
            self.video_writer_cat.release()
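    Much of `synthesize` is array-layout bookkeeping around the network: OpenCV/NumPy images are (H, W, C), while PyTorch convolutions expect a batch in (N, C, H, W). The NumPy sketch below mirrors those reshapes, using a hypothetical stand-in `fake_net` instead of `net_G` and nearest-neighbour upsampling via `ndarray.repeat` instead of the bicubic `interpolate` call (the real code resizes to arbitrary sizes such as 845x480). It only illustrates shapes and the clip-and-replicate step, not the model itself.

```python
import numpy as np

def to_nchw(img_hwc):
    """(H, W, C) image -> (1, C, H, W) batch, like permute([2, 0, 1]).unsqueeze(0)."""
    return np.transpose(img_hwc, (2, 0, 1))[None, ...]

def fake_net(batch):
    """Hypothetical stand-in for net_G: a one-channel output of shape (1, 1, H, W)."""
    return batch.mean(axis=1, keepdims=True) * 1.5   # may exceed 1.0 on purpose

def mask_to_hwc3(pred, scale):
    """(1, 1, h, w) prediction -> (h*scale, w*scale, 3) mask clipped to [0, 1]."""
    up = pred.repeat(scale, axis=2).repeat(scale, axis=3)  # nearest-neighbour upsampling
    m = np.transpose(up[0], (1, 2, 0))                     # like G_pred[0, :].permute([1, 2, 0])
    m = np.concatenate([m, m, m], axis=-1)                 # like torch.cat([G, G, G], dim=-1)
    return np.clip(m, 0.0, 1.0)                            # like np.clip(G_pred, 0.0, 1.0)

small = np.random.default_rng(1).random((4, 6, 3)).astype(np.float32)  # "resized" input
batch = to_nchw(small)
mask = mask_to_hwc3(fake_net(batch), scale=2)
print(batch.shape, mask.shape)   # (1, 3, 4, 6) (8, 12, 3)
```

    Replicating the single-channel mask to three channels lets it multiply an RGB frame element-wise without extra broadcasting logic later in the pipeline.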

    5. Replace the sky

    The sky-replaced output video is out.avi; the side-by-side before/after comparison video is compare.avi.
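    compare.avi is twice as wide as out.avi because `write_video` places the original frame and the result side by side before writing. The NumPy sketch below reproduces that frame construction under the assumption that both inputs are float RGB arrays in [0, 1] of the same shape. `make_compare_frame` is a hypothetical helper, and its `np.clip` is an extra guard added here (not in the original code) against uint8 wrap-around.

```python
import numpy as np

def make_compare_frame(img_hd, syneth):
    """Build one compare.avi frame: original on the left, result on the right.

    Inputs: float RGB arrays of shape (H, W, 3) in [0, 1].
    Output: uint8 BGR array of shape (H, 2 * W, 3), the layout cv2.VideoWriter expects.
    """
    frame_cat = np.concatenate([img_hd, syneth], axis=1)     # side by side along the width
    frame_cat = np.clip(frame_cat, 0.0, 1.0)                 # guard against uint8 wrap-around
    return (255.0 * frame_cat[:, :, ::-1]).astype(np.uint8)  # RGB -> BGR, then to 8-bit

h, w = 480, 845                                  # out_size_h, out_size_w from step 4
original = np.zeros((h, w, 3), dtype=np.float32)
result = np.zeros((h, w, 3), dtype=np.float32)
result[..., 2] = 1.0                             # pure blue in RGB
frame = make_compare_frame(original, result)
print(frame.shape)                               # (480, 1690, 3)
```

    The width of 1690 pixels matches the `(2*args.out_size_w, args.out_size_h)` frame size passed to the second `cv2.VideoWriter`; a mismatch there typically leads to frames not being written.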

    sf = SkyFilter(args)
    sf.process_video()
    initialize skybox...
    initialize network with normal
    loading the best checkpoint...
    processing video, frame 50 / 360 ... 
    processing video, frame 100 / 360 ... 
    no good point matched
    processing video, frame 150 / 360 ... 
    processing video, frame 200 / 360 ... 
    processing video, frame 250 / 360 ... 
    processing video, frame 300 / 360 ... 
    processing video, frame 350 / 360 ... 

    6. Compare the original and the sky-replaced videos

    video_name = "compare.avi"
    def arrayShow(img):
        # JPEG-encode the BGR frame at full size for inline display
        _, ret = cv2.imencode('.jpg', img)
        return Image(data=ret.tobytes())
    
    # Open a video stream
    cap = cv2.VideoCapture(video_name)
    
    frame_id = 0
    while True:
        try:
            clear_output(wait=True)  # Clear the previous output
            ret, frame = cap.read()  # Read one frame
            if ret:
                frame_id += 1
                cv2.putText(frame, str(frame_id), (5, 15), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)  # Draw frame_id
                # cv2.imencode expects BGR, so the frame is passed on without color conversion
                img = arrayShow(frame)
                display(img)  # Show the frame
                time.sleep(0.05)  # Pause briefly before the next frame
            else:
                break
        except KeyboardInterrupt:
            break  # Stop on interrupt; the capture is released below
    cap.release()

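    To generate your own video, replace sky.mp4 in test_videos and sky.jpg in skybox with your own video and image (keeping the same file names), then rerun the whole notebook. Give it a try and make your National Day footage stand out!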
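    If you would rather keep your own file names, a hedged alternative is to point the step-4 parameters at them instead. Judging from steps 3 and 4, `datadir` is the video path and `skybox` is a file name inside the `skybox/` folder. The helper `with_inputs` and the file names below are hypothetical, for illustration only.

```python
import copy

def with_inputs(params, video_path, skybox_name):
    """Return a copy of the step-4 parameter dict pointing at new inputs (hypothetical helper).

    video_path: path to your video, e.g. "./test_videos/my_trip.mp4" (illustrative name).
    skybox_name: image file name inside skybox/, e.g. "my_sky.jpg" (illustrative name).
    """
    new_params = copy.deepcopy(params)   # leave the original dict untouched
    new_params["datadir"] = video_path
    new_params["skybox"] = skybox_name
    return new_params

base = {"datadir": "./test_videos/sky.mp4", "skybox": "sky.jpg", "halo_effect": True}
mine = with_inputs(base, "./test_videos/my_trip.mp4", "my_sky.jpg")
print(mine["datadir"], mine["skybox"])
```

    In the notebook you would then rebuild the configuration, for example `str_json = json.dumps(with_inputs(parameter, ...))` followed by `args = parse_config()`, before creating `SkyFilter`. The preview in step 2 still reads `test_videos/sky.mp4` unless `video_name` is changed as well.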

    The Huawei Cloud Community wishes everyone a happy National Day and a joyful holiday!

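    Appendix

    This case comes from Huawei Cloud AI Gallery: “Magical Black Tech: Swap the Sky, Create Worlds, and Turn Footage into a Sci-Fi Blockbuster in Seconds!”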

     

    Original article: https://www.cnblogs.com/huaweiyun/p/15356018.html