  • [Cat-and-Dog Dataset] Computing the Dataset's Mean and Standard Deviation

    Dataset download:

    Link: https://pan.baidu.com/s/1l1AnBgkAAEhh0vI5_loWKw
    Extraction code: 2xq4

    Creating the dataset: https://www.cnblogs.com/xiximayou/p/12398285.html

    Reading the dataset: https://www.cnblogs.com/xiximayou/p/12422827.html

    Training: https://www.cnblogs.com/xiximayou/p/12448300.html

    Saving the model and resuming training: https://www.cnblogs.com/xiximayou/p/12452624.html

    Loading the saved model and testing: https://www.cnblogs.com/xiximayou/p/12459499.html

    Splitting off a validation set and validating during training: https://www.cnblogs.com/xiximayou/p/12464738.html

    Using a learning-rate decay schedule and testing during training: https://www.cnblogs.com/xiximayou/p/12468010.html

    Visualizing training and testing with TensorBoard: https://www.cnblogs.com/xiximayou/p/12482573.html

    Receiving arguments from the command line: https://www.cnblogs.com/xiximayou/p/12488662.html

    Evaluating the model with top-1 and top-5 accuracy: https://www.cnblogs.com/xiximayou/p/12489069.html

    Using a pretrained resnet18 model: https://www.cnblogs.com/xiximayou/p/12504579.html

    The relationship between epoch, batch size, and step: https://www.cnblogs.com/xiximayou/p/12405485.html

    There are two ways to compute the mean and std of the dataset:

    Method 1: create a count_mean_std.py file under utils

    import numpy as np
    import torchvision
    from PIL import Image
    from time import time
    from tqdm import tqdm

    def compute_mean_and_std(dataset):
        # Input: a list of (image_path, label) pairs, e.g. ImageFolder.imgs.
        # Output: per-channel mean and std in RGB order, scaled to [0, 1].
        mean_r = 0
        mean_g = 0
        mean_b = 0
        print("Computing the mean >>>")
        for img_path, _ in tqdm(dataset, ncols=80):
            # PIL opens images in RGB order; convert() guards against grayscale files.
            img = np.asarray(Image.open(img_path).convert("RGB"))
            mean_r += np.mean(img[:, :, 0])
            mean_g += np.mean(img[:, :, 1])
            mean_b += np.mean(img[:, :, 2])

        mean_r /= len(dataset)
        mean_g /= len(dataset)
        mean_b /= len(dataset)

        diff_r = 0
        diff_g = 0
        diff_b = 0

        N = 0
        print("Computing the std >>>")
        for img_path, _ in tqdm(dataset, ncols=80):
            img = np.asarray(Image.open(img_path).convert("RGB"))
            diff_r += np.sum(np.power(img[:, :, 0] - mean_r, 2))
            diff_g += np.sum(np.power(img[:, :, 1] - mean_g, 2))
            diff_b += np.sum(np.power(img[:, :, 2] - mean_b, 2))

            N += np.prod(img[:, :, 0].shape)

        std_r = np.sqrt(diff_r / N)
        std_g = np.sqrt(diff_g / N)
        std_b = np.sqrt(diff_b / N)

        mean = (mean_r / 255.0, mean_g / 255.0, mean_b / 255.0)
        std = (std_r / 255.0, std_g / 255.0, std_b / 255.0)
        return mean, std

    path = "/content/drive/My Drive/colab notebooks/data/dogcat"
    train_path = path + "/train"
    test_path = path + "/test"
    val_path = path + "/val"
    train_data = torchvision.datasets.ImageFolder(train_path)
    val_data = torchvision.datasets.ImageFolder(val_path)
    test_data = torchvision.datasets.ImageFolder(test_path)
    # train_mean, train_std = compute_mean_and_std(train_data.imgs)
    time_start = time()
    val_mean, val_std = compute_mean_and_std(val_data.imgs)
    time_end = time()
    print("Validation set computation time:", round(time_end - time_start, 4), "s")
    # test_mean, test_std = compute_mean_and_std(test_data.imgs)
    # print("Train set mean: {}, std: {}".format(train_mean, train_std))
    print("Validation set mean: {}".format(val_mean))
    print("Validation set std: {}".format(val_std))
    # print("Test set mean: {}, std: {}".format(test_mean, test_std))

    Result:

    Note: since we read the dataset with PyTorch's datasets.ImageFolder, we pull the image paths out via train_data.imgs. The value of train_data.imgs is a list of the form [(image_path_1, label), (image_path_2, label), ...], so the loop for img_path, _ in dataset extracts exactly the image paths. Each image is then opened with Image.open(), converted to a numpy array, and the mean and std are accumulated. Don't be misled by how fast the run in the screenshot looks: that was after several runs, with the data coming from cache; the first run is much slower. Only the validation set is computed here; the training set has nearly 20,000 images and would be slower still, so it is skipped.
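    The two passes above (per-image channel means first, then squared deviations over all pixels) can be sanity-checked against a direct computation on small synthetic arrays. This is a minimal sketch using in-memory numpy arrays in place of image files; note that averaging per-image means weights every image equally, so it matches the direct per-pixel mean exactly only when all images have the same size:

```python
import numpy as np

rng = np.random.default_rng(0)
# Three fake 4x4 RGB "images" with values in 0-255, all the same size.
images = [rng.integers(0, 256, size=(4, 4, 3)).astype(np.float64) for _ in range(3)]

# Pass 1: average of per-image channel means.
mean = sum(img.mean(axis=(0, 1)) for img in images) / len(images)

# Pass 2: sum of squared deviations over all pixels, per channel.
diff = sum(((img - mean) ** 2).sum(axis=(0, 1)) for img in images)
n_pixels = sum(img.shape[0] * img.shape[1] for img in images)
std = np.sqrt(diff / n_pixels)

# Direct computation over all pixels at once, for comparison.
stacked = np.concatenate([img.reshape(-1, 3) for img in images])
assert np.allclose(mean, stacked.mean(axis=0))
assert np.allclose(std, stacked.std(axis=0))
print(mean / 255.0, std / 255.0)
```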

    Once you have the mean and std, they can be used in the data augmentation pipeline like this:

    train_transform = torchvision.transforms.Compose([
        torchvision.transforms.RandomResizedCrop(size=224,
                                                 scale=(0.08, 1.0)),
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                         std=(0.229, 0.224, 0.225)),
    ])
    val_transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize(256),
        torchvision.transforms.CenterCrop(224),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                         std=(0.229, 0.224, 0.225)),
    ])

    Note that normalization goes last, after all the other augmentations. The earlier augmentations operate on PIL images, so they must come before ToTensor(); after ToTensor(), the pixel values are tensors scaled to the range [0, 1], which is exactly what Normalize expects.
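    As a sanity check on that ordering: Normalize simply computes (x - mean) / std per channel on the [0, 1]-scaled, channels-first tensor that ToTensor() produces. A minimal numpy sketch of the arithmetic, using the same ImageNet mean/std values as above:

```python
import numpy as np

# A 2x2 "image" with 3 channels, values in 0-255 as PIL would give them.
img = np.array([[[255, 128, 0], [64, 64, 64]],
                [[0, 0, 0], [255, 255, 255]]], dtype=np.uint8)

# ToTensor(): scale to [0, 1] and move channels first (C, H, W).
tensor = img.astype(np.float32).transpose(2, 0, 1) / 255.0

# Normalize(): per-channel (x - mean) / std.
mean = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
std = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)
normalized = (tensor - mean) / std

# A full-intensity pixel in channel 0 maps to (1.0 - 0.485) / 0.229.
print(normalized[0, 0, 0])
```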

    Method 2:

    import os
    import random
    import cv2
    import numpy as np
    from tqdm import tqdm

    # Calculate means and std over a random sample of the training images.
    train_txt_path = './train_val_list.txt'

    CNum = 10000   # how many images to sample for the computation

    img_h, img_w = 224, 224
    imgs = np.zeros([img_h, img_w, 3, 0])  # empty along the sample axis
    means, stdevs = [], []

    with open(train_txt_path, 'r') as f:
        lines = f.readlines()
        random.shuffle(lines)  # shuffle so the sample is random

        for i in tqdm(range(CNum)):
            img_path = os.path.join('./train', lines[i].rstrip().split()[0])

            img = cv2.imread(img_path)
            img = cv2.resize(img, (img_w, img_h))  # cv2.resize takes (width, height)
            img = img[:, :, :, np.newaxis]

            imgs = np.concatenate((imgs, img), axis=3)

    imgs = imgs.astype(np.float32) / 255.

    for i in tqdm(range(3)):
        pixels = imgs[:, :, i, :].ravel()  # flatten this channel into one row
        means.append(np.mean(pixels))
        stdevs.append(np.std(pixels))

    # cv2 reads images in BGR order; PIL/skimage read RGB and would not need this.
    means.reverse()  # BGR --> RGB
    stdevs.reverse()

    print("normMean = {}".format(means))
    print("normStd = {}".format(stdevs))
    print('transforms.Normalize(normMean = {}, normStd = {})'.format(means, stdevs))

    Excerpted from the web, for reference.
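    One caveat with that snippet: growing the array with np.concatenate inside the loop copies everything already accumulated on each iteration, which is quadratic in the number of sampled images. A sketch of the same per-channel computation that collects the images in a list and stacks them once at the end; random arrays stand in for the resized BGR images that cv2.imread would return:

```python
import numpy as np

rng = np.random.default_rng(42)
# Stand-ins for resized BGR images loaded with cv2 (dtype uint8, shape H x W x 3).
sample = [rng.integers(0, 256, size=(224, 224, 3), dtype=np.uint8) for _ in range(16)]

# Stack once at the end instead of np.concatenate on every iteration:
# a single O(n) copy instead of O(n^2) total copying.
imgs = np.stack(sample, axis=0).astype(np.float32) / 255.0  # (N, H, W, 3)

# Per-channel mean/std over all sampled pixels, then BGR -> RGB.
means = imgs.mean(axis=(0, 1, 2))[::-1]
stdevs = imgs.std(axis=(0, 1, 2))[::-1]

print("normMean = {}".format(means.tolist()))
print("normStd = {}".format(stdevs.tolist()))
```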

    So far we have always read the dataset with datasets.ImageFolder; in the next post we will read the cat-and-dog dataset the second way.

  • Original post: https://www.cnblogs.com/xiximayou/p/12507149.html