zoukankan      html  css  js  c++  java
  • 深度学习代码片段合集

    计算数据集的mean和std

    """
    in this script, we calculate the image per channel mean and standard
    deviation in the training set, do not calculate the statistics on the
    whole dataset, as per here http://cs231n.github.io/neural-networks-2/#datapre
    """
    
    import numpy as np
    from os import listdir
    from os.path import join, isdir
    from glob import glob
    import cv2
    import timeit
    
    # number of channels of the dataset image, 3 for color jpg, 1 for grayscale img
    # you need to change it to reflect your dataset
    CHANNEL_NUM = 3
    
    
    def cal_dir_stat(root):
        cls_dirs = [d for d in listdir(root) if isdir(join(root, d))]
        pixel_num = 0 # store all pixel number in the dataset
        channel_sum = np.zeros(CHANNEL_NUM)
        channel_sum_squared = np.zeros(CHANNEL_NUM)
    
        for idx, d in enumerate(cls_dirs):
            print("#{} class".format(idx))
            im_pths = glob(join(root, d, "*.jpg"))
    
            for path in im_pths:
                im = cv2.imread(path) # image in M*N*CHANNEL_NUM shape, channel in BGR order
                im = im/255.0
                pixel_num += (im.size/CHANNEL_NUM)
                channel_sum += np.sum(im, axis=(0, 1))
                channel_sum_squared += np.sum(np.square(im), axis=(0, 1))
    
        bgr_mean = channel_sum / pixel_num
        bgr_std = np.sqrt(channel_sum_squared / pixel_num - np.square(bgr_mean))
        
        # change the format from bgr to rgb
        rgb_mean = list(bgr_mean)[::-1]
        rgb_std = list(bgr_std)[::-1]
        
        return rgb_mean, rgb_std
    
    # The script assumes that under train_root, there are separate directories for each class
    # of training images.
    train_root = "/hd1/jdhao/firearm-dataset/train/"
    start = timeit.default_timer()
    mean, std = cal_dir_stat(train_root)
    end = timeit.default_timer()
    print("elapsed time: {}".format(end-start))
    print("mean:{}
    std:{}".format(mean, std))
    
  • 相关阅读:
    spring 事务
    Servlet详解之两个init方法的作用
    被request.getLocalAddr()苦闷了很久
    Java获取IP地址:request.getRemoteAddr()警惕
    MongoDB笔记
    hexo+github搭建博客
    Python处理Excel(使用openpyxl库)
    Wireshark使用学习
    查看开启操作系统端口
    记录Centos7服务器搭建过程
  • 原文地址:https://www.cnblogs.com/sdu20112013/p/12069512.html
Copyright © 2011-2022 走看看