zoukankan      html  css  js  c++  java
  • 【colab pytorch】数据预处理

    1、计算数据集的均值和标准差

    import os
    import cv2
    import numpy as np
    from torch.utils.data import Dataset
    from PIL import Image
    
    def compute_mean_and_std(dataset):
        """Compute the per-channel mean and standard deviation of an image dataset.

        Each item of ``dataset`` is an ``(image, label)`` pair; ``image`` must
        convert via ``np.asarray`` to an ``H x W x C`` array with ``C >= 3``
        (e.g. a PIL RGB image). Only channels 0, 1 and 2 are used.

        The dataset is iterated twice (first for the mean, then for the squared
        deviations), so it must be re-iterable (a ``torch.utils.data.Dataset``
        or a list). It should not apply random transforms, or the two passes
        see different pixels.

        Statistics are taken over every pixel in the dataset, so larger images
        contribute proportionally more. The standard deviation is the population
        std (divides by the pixel count). Both results are divided by 255, i.e.
        8-bit pixel values are assumed.

        Args:
            dataset: re-iterable of ``(image, label)`` pairs.

        Returns:
            ``(mean, std)``: two 3-tuples of Python floats, in the array's
            channel order (channel 0, 1, 2 -- R, G, B for RGB PIL images).

        Raises:
            ValueError: if the dataset contains no pixels.
        """
        def _pixels(img):
            # Flatten channels 0..2 into an (H*W, 3) float64 array. Fancy indexing
            # raises IndexError for images with fewer than 3 channels instead of
            # silently mixing channels during the reshape.
            arr = np.asarray(img, dtype=np.float64)
            return arr[:, :, [0, 1, 2]].reshape(-1, 3)

        channel_sum = np.zeros(3)
        num_pixels = 0
        for img, _ in dataset:
            pixels = _pixels(img)
            channel_sum += pixels.sum(axis=0)
            num_pixels += pixels.shape[0]

        if num_pixels == 0:
            raise ValueError("dataset contains no pixels")

        # Pixel-weighted mean: the same population the std below is computed over
        # (averaging per-image means would be wrong when image sizes differ).
        mean = channel_sum / num_pixels

        sq_dev_sum = np.zeros(3)
        for img, _ in dataset:
            sq_dev_sum += ((_pixels(img) - mean) ** 2).sum(axis=0)
        std = np.sqrt(sq_dev_sum / num_pixels)

        return tuple((mean / 255.0).tolist()), tuple((std / 255.0).tolist())

    2、获取视频的基本信息(高度、宽度、帧数、帧率)

    import cv2

    # Read basic metadata (height, width, frame count, fps) of the video at
    # ``mp4_path`` (assumed to be defined by the surrounding code).
    video = cv2.VideoCapture(mp4_path)
    try:
        # Opening an unreadable path does not necessarily raise; check explicitly
        # so we don't continue with zeroed-out metadata.
        if not video.isOpened():
            raise OSError(f"cannot open video: {mp4_path}")
        height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
        width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
        # NOTE(review): frame count comes from container metadata and may be
        # approximate for some formats -- confirm if exact counts matter.
        num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        # Round rather than truncate: 29.97 fps becomes 30, not 29.
        fps = int(round(video.get(cv2.CAP_PROP_FPS)))
    finally:
        # Always release the capture handle, even if a check above fails.
        video.release()

    3、将视频均分为 K 段,每段采样一帧

    K = self._num_segments
    if is_train:
        if num_frames > K:
            # Random index for each segment.
            frame_indices = torch.randint(
                high=num_frames // K, size=(K,), dtype=torch.long)
            frame_indices += num_frames // K * torch.arange(K)
        else:
            frame_indices = torch.randint(
                high=num_frames, size=(K - num_frames,), dtype=torch.long)
            frame_indices = torch.sort(torch.cat((
                torch.arange(num_frames), frame_indices)))[0]
    else:
        if num_frames > K:
            # Middle index for each segment.
            frame_indices = num_frames / K // 2
            frame_indices += num_frames // K * torch.arange(K)
        else:
            frame_indices = torch.sort(torch.cat((                              
                torch.arange(num_frames), torch.arange(K - num_frames))))[0]
    assert frame_indices.size() == (K,)
    return [frame_indices[i] for i in range(K)]

    4、常用训练和验证数据预处理

    # Normalization constants shared by both pipelines (the standard ImageNet
    # per-channel RGB mean/std used with torchvision's pretrained models).
    IMAGENET_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_STD = (0.229, 0.224, 0.225)

    # Training: random-resized 224x224 crop plus horizontal flip for augmentation.
    train_transform = torchvision.transforms.Compose([
        torchvision.transforms.RandomResizedCrop(size=224, scale=(0.08, 1.0)),
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])
    # Validation: deterministic resize to 256, then a 224x224 center crop.
    val_transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize(256),
        torchvision.transforms.CenterCrop(224),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])
  • 相关阅读:
    好的 文章链接汇总
    webpack之postcss集成
    移动端适配方法合集
    每天干的啥?(2017.3)
    每天干的啥?(2017.2)
    【2016年终总结】
    每天干的啥?(2017.1)
    PHP获取接口数据(模拟Get)
    每天干的啥?(2016.12)
    更换域名后的数据库sql的执行命令
  • 原文地址:https://www.cnblogs.com/xiximayou/p/12505923.html
Copyright © 2011-2022 走看看