  • Deep learning code snippet collection

    Computing the mean and std of a dataset
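In formula form, the script accumulates running sums in one pass and then evaluates the following. Here N is the total number of pixels over all training images and x_{i,c} is the value of pixel i in channel c after dividing by 255 (this notation is ours, not the original author's):

```latex
\mu_c = \frac{1}{N}\sum_{i=1}^{N} x_{i,c}, \qquad
\sigma_c = \sqrt{\frac{1}{N}\sum_{i=1}^{N} x_{i,c}^{2} - \mu_c^{2}}
```

This is the population (divide-by-N) standard deviation. Only the pixel count, the per-channel sum and the per-channel sum of squares have to be kept, so the images never need to be held in memory together.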

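    """
    In this script, we calculate the per-channel mean and standard deviation
    of the images in the training set only. Do not compute these statistics
    on the whole dataset (training + validation + test); compute them on the
    training data and then apply them to the validation and test data, see
    http://cs231n.github.io/neural-networks-2/#datapre
    """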
    
    import numpy as np
    from os import listdir
    from os.path import join, isdir
    from glob import glob
    import cv2
    import timeit
    
    # number of channels per image: 3 for color JPEGs, 1 for grayscale images;
    # change it to match your dataset
    CHANNEL_NUM = 3
    
    
    def cal_dir_stat(root):
        cls_dirs = [d for d in listdir(root) if isdir(join(root, d))]
        pixel_num = 0  # total number of pixels (per channel) in the dataset
        channel_sum = np.zeros(CHANNEL_NUM)
        channel_sum_squared = np.zeros(CHANNEL_NUM)
    
        for idx, d in enumerate(cls_dirs):
            print("#{} class".format(idx))
            im_pths = glob(join(root, d, "*.jpg"))
    
            for path in im_pths:
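                # read as 3-channel BGR or as single-channel grayscale to match CHANNEL_NUM
                flag = cv2.IMREAD_COLOR if CHANNEL_NUM == 3 else cv2.IMREAD_GRAYSCALE
                im = cv2.imread(path, flag)  # M*N*3 (BGR order) or M*N for grayscale
                if im is None:  # skip files OpenCV cannot decode
                    continue
                im = im / 255.0
                pixel_num += im.shape[0] * im.shape[1]  # pixels per channel in this image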
                channel_sum += np.sum(im, axis=(0, 1))
                channel_sum_squared += np.sum(np.square(im), axis=(0, 1))
    
        bgr_mean = channel_sum / pixel_num
        bgr_std = np.sqrt(channel_sum_squared / pixel_num - np.square(bgr_mean))
        
        # change the format from bgr to rgb
        rgb_mean = list(bgr_mean)[::-1]
        rgb_std = list(bgr_std)[::-1]
        
        return rgb_mean, rgb_std
    
    # The script assumes that under train_root, there are separate directories for each class
    # of training images.
    train_root = "/hd1/jdhao/firearm-dataset/train/"
    start = timeit.default_timer()
    mean, std = cal_dir_stat(train_root)
    end = timeit.default_timer()
    print("elapsed time: {}".format(end-start))
    print("mean:{}\nstd:{}".format(mean, std))
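Because the loop only keeps three running totals, the arithmetic can be checked without OpenCV or image files. The sketch below is a minimal, hypothetical refactor (the function name `running_channel_stats` and the synthetic data are ours, not the original author's). It accumulates the same quantities over in-memory arrays and compares the result with NumPy's `mean`/`std` computed over all pixels stacked together. It assumes the images are already scaled to [0, 1] and stored as H x W x C arrays, or as H x W arrays for grayscale with `channel_num=1`.

```python
import numpy as np


def running_channel_stats(images, channel_num=3):
    """Per-channel mean and population std via running sums (hypothetical helper).

    Accumulates the pixel count, the per-channel sum and the per-channel sum of
    squares, then uses Var = E[x^2] - E[x]^2, as cal_dir_stat does.
    `images` is any iterable of arrays already scaled to [0, 1].
    """
    pixel_num = 0
    channel_sum = np.zeros(channel_num)
    channel_sum_squared = np.zeros(channel_num)
    for im in images:
        # one row per pixel, one column per channel
        im = np.asarray(im, dtype=np.float64).reshape(-1, channel_num)
        pixel_num += im.shape[0]
        channel_sum += im.sum(axis=0)
        channel_sum_squared += np.square(im).sum(axis=0)
    mean = channel_sum / pixel_num
    std = np.sqrt(channel_sum_squared / pixel_num - np.square(mean))
    return mean, std


# Synthetic "dataset": three images of different sizes, values in [0, 1).
rng = np.random.default_rng(0)
images = [rng.random((h, w, 3)) for h, w in [(4, 5), (8, 3), (6, 6)]]

mean, std = running_channel_stats(images)

# Reference: stack every pixel of every image and let NumPy compute the stats.
all_pixels = np.concatenate([im.reshape(-1, 3) for im in images])
print(np.allclose(mean, all_pixels.mean(axis=0)))  # True
print(np.allclose(std, all_pixels.std(axis=0)))    # True
```

`np.std` defaults to `ddof=0`, the population standard deviation, which is what the single-pass formula computes, so the two agree up to floating-point rounding. Note that `cal_dir_stat` returns its statistics in RGB order even though OpenCV loads images as BGR.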
    
  • Original article: https://www.cnblogs.com/sdu20112013/p/12069512.html