  • Deep learning code snippet collection

    Computing the mean and std of a dataset
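    The script below relies on the identity Var[X] = E[X^2] - (E[X])^2, which lets it keep only two running sums per channel instead of holding every pixel in memory. Written out for channel c (the notation is ours, not the original post's), with N the total number of pixels over all training images and x_{i,c} the value of pixel i in channel c after scaling to [0, 1], the quantities it computes are the following. Note this is the population standard deviation (dividing by N), the same convention as NumPy's default `np.std` with `ddof=0`.

```latex
\mu_c = \frac{1}{N}\sum_{i=1}^{N} x_{i,c},
\qquad
\sigma_c = \sqrt{\frac{1}{N}\sum_{i=1}^{N} x_{i,c}^{2} - \mu_c^{2}}
```

    In the script, `channel_sum` holds the first sum, `channel_sum_squared` the second, and `pixel_num` is N.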

    """
    In this script, we calculate the per-channel image mean and standard
    deviation on the training set only. Do not calculate these statistics on
    the whole dataset; see http://cs231n.github.io/neural-networks-2/#datapre
    """
    
    import numpy as np
    from os import listdir
    from os.path import join, isdir
    from glob import glob
    import cv2
    import timeit
    
    # number of channels in the dataset images: 3 for color JPEGs, 1 for grayscale images
    # change this to match your dataset
    CHANNEL_NUM = 3
    
    
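    def cal_dir_stat(root):
        cls_dirs = [d for d in listdir(root) if isdir(join(root, d))]
        pixel_num = 0  # total number of pixels (per channel) in the dataset
        channel_sum = np.zeros(CHANNEL_NUM)
        channel_sum_squared = np.zeros(CHANNEL_NUM)
        # cv2.imread returns 3-channel BGR by default, so read grayscale explicitly when CHANNEL_NUM == 1
        read_flag = cv2.IMREAD_COLOR if CHANNEL_NUM == 3 else cv2.IMREAD_GRAYSCALE
    
        for idx, d in enumerate(cls_dirs):
            print("#{} class".format(idx))
            im_pths = glob(join(root, d, "*.jpg"))
    
            for path in im_pths:
                im = cv2.imread(path, read_flag)  # H*W*3 in BGR order (color) or H*W (grayscale)
                if im is None:  # cv2.imread returns None for unreadable or corrupt files
                    print("skipping unreadable image: {}".format(path))
                    continue
                im = im / 255.0  # scale pixel values to [0, 1]
                pixel_num += im.size / CHANNEL_NUM
                channel_sum += np.sum(im, axis=(0, 1))
                channel_sum_squared += np.sum(np.square(im), axis=(0, 1))
    
        # Var[X] = E[X^2] - (E[X])^2 per channel; clip tiny negatives from floating-point cancellation
        bgr_mean = channel_sum / pixel_num
        bgr_std = np.sqrt(np.maximum(channel_sum_squared / pixel_num - np.square(bgr_mean), 0.0))
    
        # OpenCV loads images in BGR order; reverse to RGB (a no-op for a single channel)
        rgb_mean = list(bgr_mean)[::-1]
        rgb_std = list(bgr_std)[::-1]
    
        return rgb_mean, rgb_std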
    
    # The script assumes that under train_root, there are separate directories for each class
    # of training images.
    train_root = "/hd1/jdhao/firearm-dataset/train/"
    start = timeit.default_timer()
    mean, std = cal_dir_stat(train_root)
    end = timeit.default_timer()
    print("elapsed time: {}".format(end-start))
    print("mean:{}\nstd:{}".format(mean, std))
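    As a quick sanity check, here is a minimal sketch that replays the accumulation loop of `cal_dir_stat` on synthetic NumPy arrays (so no image files or OpenCV are needed) and compares the result with `np.mean` / `np.std` computed over all pixels at once. The helper names `accumulate_stats` and `normalize` are hypothetical and not part of the original script, and the random arrays merely stand in for decoded, [0, 1]-scaled images. The last lines illustrate the preprocessing rule the docstring cites: statistics come from the training data only and are then applied unchanged to held-out data.

```python
import numpy as np


def accumulate_stats(images, channel_num=3):
    """Per-channel mean/std via running sums, mirroring cal_dir_stat.

    `images` is an iterable of H*W*C float arrays already scaled to [0, 1].
    """
    pixel_num = 0
    channel_sum = np.zeros(channel_num)
    channel_sum_squared = np.zeros(channel_num)
    for im in images:
        pixel_num += im.size / channel_num
        channel_sum += im.sum(axis=(0, 1))
        channel_sum_squared += np.square(im).sum(axis=(0, 1))
    mean = channel_sum / pixel_num
    # clip tiny negative values caused by floating-point cancellation
    std = np.sqrt(np.maximum(channel_sum_squared / pixel_num - np.square(mean), 0.0))
    return mean, std


def normalize(im, mean, std):
    """Subtract the training-set mean and divide by its std, per channel."""
    return (im - mean) / std


rng = np.random.default_rng(0)
# synthetic "training images" of different sizes, standing in for real JPEGs
train = [rng.random((h, w, 3)) for h, w in [(32, 48), (64, 64), (16, 20)]]
mean, std = accumulate_stats(train)

# reference: flatten every pixel into one (N, 3) array and use np.mean / np.std
all_pixels = np.concatenate([im.reshape(-1, 3) for im in train])
print(np.allclose(mean, all_pixels.mean(axis=0)))  # True
print(np.allclose(std, all_pixels.std(axis=0)))    # True

# the training statistics are reused, not recomputed, for held-out data
test_im = rng.random((24, 24, 3))
test_norm = normalize(test_im, mean, std)
```

    Accumulating sums keeps memory constant no matter how many images there are, which is why the script takes this route rather than loading every pixel and calling `np.std` directly.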
    
  • Original article: https://www.cnblogs.com/sdu20112013/p/12069512.html