1、计算数据集的均值和标准差
import os
import cv2
import numpy as np
from torch.utils.data import Dataset
from PIL import Image


def compute_mean_and_std(dataset):
    """Compute the per-channel mean and standard deviation of an image dataset.

    Statistics are taken over all pixels of all images, so every pixel
    carries the same weight regardless of the size of the image it belongs
    to. For datasets whose images all have the same size this equals the
    average of the per-image means. The dataset is traversed only once.

    Args:
        dataset: Iterable of ``(image, label)`` pairs, e.g. a PyTorch
            dataset. Each image is converted with ``np.asarray`` and must
            yield an ``H x W x C`` array with at least three channels; only
            channels 0, 1 and 2 are used. The label is ignored.

    Returns:
        A pair ``(mean, std)``. Each is a 3-tuple of Python floats in
        channel order (0, 1, 2), divided by 255.

    Raises:
        ValueError: If the dataset yields no pixels at all.
        IndexError: If an image does not have three indexable channels.
    """
    channel_sum = np.zeros(3, dtype=np.float64)
    channel_sq_sum = np.zeros(3, dtype=np.float64)
    num_pixels = 0

    for img, _ in dataset:
        # np.asarray also accepts PIL images. Indexing channels 0..2
        # explicitly raises IndexError for images with fewer channels
        # instead of silently mis-reshaping them.
        pixels = np.asarray(img)[:, :, [0, 1, 2]]
        # Accumulate in float64 so integer pixel types cannot overflow.
        pixels = pixels.reshape(-1, 3).astype(np.float64)
        channel_sum += pixels.sum(axis=0)
        channel_sq_sum += np.square(pixels).sum(axis=0)
        num_pixels += pixels.shape[0]

    if num_pixels == 0:
        raise ValueError('cannot compute statistics of an empty dataset')

    mean = channel_sum / num_pixels
    # Var[x] = E[x^2] - E[x]^2; clamp at zero because rounding can leave a
    # tiny negative value when a channel is constant.
    variance = np.maximum(channel_sq_sum / num_pixels - np.square(mean), 0.0)
    std = np.sqrt(variance)

    return tuple((mean / 255.0).tolist()), tuple((std / 255.0).tolist())
2、得到视频数据的基本信息
import cv2

# Read basic metadata (frame size, frame count, frame rate) of the video
# file at `mp4_path`, which must be defined before this snippet runs.
video = cv2.VideoCapture(mp4_path)
try:
    # VideoCapture does not raise when the file is missing or unreadable;
    # without this check the queries below would quietly produce zeros.
    if not video.isOpened():
        raise IOError('Cannot open video: {}'.format(mp4_path))
    height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
    width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
    num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
    # NOTE: int() truncates fractional frame rates (e.g. 29.97 -> 29).
    fps = int(video.get(cv2.CAP_PROP_FPS))
finally:
    # Always free the underlying capture handle, even on failure.
    video.release()
3、每段采样一帧视频
K = self._num_segments if is_train: if num_frames > K: # Random index for each segment. frame_indices = torch.randint( high=num_frames // K, size=(K,), dtype=torch.long) frame_indices += num_frames // K * torch.arange(K) else: frame_indices = torch.randint( high=num_frames, size=(K - num_frames,), dtype=torch.long) frame_indices = torch.sort(torch.cat(( torch.arange(num_frames), frame_indices)))[0] else: if num_frames > K: # Middle index for each segment. frame_indices = num_frames / K // 2 frame_indices += num_frames // K * torch.arange(K) else: frame_indices = torch.sort(torch.cat(( torch.arange(num_frames), torch.arange(K - num_frames))))[0] assert frame_indices.size() == (K,) return [frame_indices[i] for i in range(K)]
4、常用训练和验证数据预处理
# Per-channel normalization statistics, shared so that training and
# validation inputs are normalized identically.
# NOTE(review): these look like the usual ImageNet statistics - confirm they
# match the data / pretrained weights in use.
_NORM_MEAN = (0.485, 0.456, 0.406)
_NORM_STD = (0.229, 0.224, 0.225)

# Training pipeline: random resized crop to 224 and random horizontal flip
# for augmentation, then tensor conversion and normalization.
train_transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.RandomResizedCrop(size=224, scale=(0.08, 1.0)),
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)

# Validation pipeline: deterministic resize to 256 followed by a 224 center
# crop, then the same tensor conversion and normalization.
val_transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize(256),
        torchvision.transforms.CenterCrop(224),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)