  • Deep learning code snippet collection

    Computing the mean and std of a dataset
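    The script below makes a single pass over the training images. For each channel c it keeps only three running values: the pixel count N, the sum of pixel values, and the sum of their squares. The statistics then follow from the identities below, where x_{i,c} is the value of pixel i in channel c after dividing by 255. This is the population standard deviation (it divides by N), which matches NumPy's `np.std` default of `ddof=0`.

```latex
\mu_c = \frac{1}{N}\sum_{i=1}^{N} x_{i,c}, \qquad
\sigma_c = \sqrt{\frac{1}{N}\sum_{i=1}^{N} x_{i,c}^{2} \;-\; \mu_c^{2}}
```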

    """
    in this script, we calculate the image per channel mean and standard
    deviation in the training set, do not calculate the statistics on the
    whole dataset, as per here http://cs231n.github.io/neural-networks-2/#datapre
    """
    
    import numpy as np
    from os import listdir
    from os.path import join, isdir
    from glob import glob
    import cv2
    import timeit
    
    # Number of channels per image: 3 for color JPEGs, 1 for grayscale images.
    # Change it to match your dataset.
    CHANNEL_NUM = 3
    
    
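    def cal_dir_stat(root):
        """Return per-channel (mean, std) of all *.jpg images in root's class subdirectories, in RGB order."""
        cls_dirs = [d for d in listdir(root) if isdir(join(root, d))]
        pixel_num = 0  # total number of pixels (per channel) in the dataset
        channel_sum = np.zeros(CHANNEL_NUM)
        channel_sum_squared = np.zeros(CHANNEL_NUM)
        # cv2.imread loads 3-channel BGR by default; read single-channel for grayscale data
        read_flag = cv2.IMREAD_COLOR if CHANNEL_NUM == 3 else cv2.IMREAD_GRAYSCALE

        for idx, d in enumerate(cls_dirs):
            print("#{} class".format(idx))
            im_pths = glob(join(root, d, "*.jpg"))

            for path in im_pths:
                im = cv2.imread(path, read_flag)  # M*N*3 (BGR order) or M*N array
                if im is None:  # skip unreadable or corrupt files
                    continue
                im = im / 255.0
                pixel_num += im.shape[0] * im.shape[1]
                channel_sum += np.sum(im, axis=(0, 1))
                channel_sum_squared += np.sum(np.square(im), axis=(0, 1))

        # Var[x] = E[x^2] - (E[x])^2, clamped at 0 against floating-point round-off
        bgr_mean = channel_sum / pixel_num
        bgr_std = np.sqrt(np.maximum(channel_sum_squared / pixel_num - np.square(bgr_mean), 0.0))

        # Convert from BGR to RGB order (a no-op for a single channel)
        rgb_mean = list(bgr_mean)[::-1]
        rgb_std = list(bgr_std)[::-1]

        return rgb_mean, rgb_std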
    
    # The script assumes that train_root contains one subdirectory per class,
    # each holding that class's training images (*.jpg).
    if __name__ == "__main__":
        train_root = "/hd1/jdhao/firearm-dataset/train/"  # change to your own path
        start = timeit.default_timer()
        mean, std = cal_dir_stat(train_root)
        end = timeit.default_timer()
        print("elapsed time: {:.2f} s".format(end - start))
        print("mean:{}\nstd:{}".format(mean, std))
    
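    Following the cs231n note cited in the docstring, the statistics computed on the training set are then applied unchanged to validation and test images. Below is a minimal sketch of that step. It assumes an H x W x 3 uint8 image that is already in RGB order (an image from `cv2.imread` would first need `cv2.cvtColor(im, cv2.COLOR_BGR2RGB)`) and the `rgb_mean`/`rgb_std` lists returned by `cal_dir_stat`. `normalize_image` is a hypothetical name, and the statistics in the demo are made-up values.

```python
import numpy as np


def normalize_image(im_uint8, rgb_mean, rgb_std):
    """Hypothetical helper: standardize an H x W x 3 RGB uint8 image with training-set stats.

    rgb_mean / rgb_std are length-3 sequences in RGB order, computed on pixel
    values scaled to [0, 1], as cal_dir_stat does.
    """
    im = im_uint8.astype(np.float64) / 255.0       # same [0, 1] scaling as the stats
    mean = np.asarray(rgb_mean, dtype=np.float64)  # shape (3,) broadcasts over H x W
    std = np.asarray(rgb_std, dtype=np.float64)
    return (im - mean) / std


if __name__ == "__main__":
    # Made-up statistics, for illustration only
    rgb_mean = [0.5, 0.25, 0.75]
    rgb_std = [0.25, 0.5, 0.125]
    white = np.full((2, 2, 3), 255, dtype=np.uint8)  # an all-white test image
    out = normalize_image(white, rgb_mean, rgb_std)
    print(out[0, 0])
```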
  • Original post: https://www.cnblogs.com/sdu20112013/p/12069512.html