zoukankan      html  css  js  c++  java
  • 利用Python 脚本生成 .h5 文件 代码

    利用 Python 脚本生成 .h5 文件
    （注：以下代码为 Python 2 写法，并依赖 scipy.misc.imread / imresize，这两个函数已分别在 SciPy 1.2 和 1.3 中移除。）

      1 import os, json, argparse
      2 from threading import Thread
      3 from Queue import Queue
      4 
      5 import numpy as np
      6 from scipy.misc import imread, imresize
      7 import h5py
      8 
      9 """
     10 Create an HDF5 file of images for training a feedforward style transfer model.
     11 """
     12 
     13 parser = argparse.ArgumentParser()
     14 parser.add_argument('--train_dir', default='/media/wangxiao/WangXiao_Dataset/CoCo/train2014')
     15 parser.add_argument('--val_dir', default='/media/wangxiao/WangXiao_Dataset/CoCo/val2014')
     16 parser.add_argument('--output_file', default='/media/wangxiao/WangXiao_Dataset/CoCo/coco-256.h5')
     17 parser.add_argument('--height', type=int, default=256)
     18 parser.add_argument('--width', type=int, default=256)
     19 parser.add_argument('--max_images', type=int, default=-1)
     20 parser.add_argument('--num_workers', type=int, default=2)
     21 parser.add_argument('--include_val', type=int, default=1)
     22 parser.add_argument('--max_resize', default=16, type=int)
     23 args = parser.parse_args()
     24 
     25 
     26 def add_data(h5_file, image_dir, prefix, args):
     27   # Make a list of all images in the source directory
     28   image_list = []
     29   image_extensions = {'.jpg', '.jpeg', '.JPG', '.JPEG', '.png', '.PNG'}
     30   for filename in os.listdir(image_dir):
     31     ext = os.path.splitext(filename)[1]
     32     if ext in image_extensions:
     33       image_list.append(os.path.join(image_dir, filename))
     34   num_images = len(image_list)
     35 
     36   # Resize all images and copy them into the hdf5 file
     37   # We'll bravely try multithreading
     38   dset_name = os.path.join(prefix, 'images')
     39   dset_size = (num_images, 3, args.height, args.width)
     40   imgs_dset = h5_file.create_dataset(dset_name, dset_size, np.uint8)
     41   
     42   # input_queue stores (idx, filename) tuples,
     43   # output_queue stores (idx, resized_img) tuples
     44   input_queue = Queue()
     45   output_queue = Queue()
     46   
     47   # Read workers pull images off disk and resize them
     48   def read_worker():
     49     while True:
     50       idx, filename = input_queue.get()
     51       img = imread(filename)
     52       try:
     53         # First crop the image so its size is a multiple of max_resize
     54         H, W = img.shape[0], img.shape[1]
     55         H_crop = H - H % args.max_resize
     56         W_crop = W - W % args.max_resize
     57         img = img[:H_crop, :W_crop]
     58         img = imresize(img, (args.height, args.width))
     59       except (ValueError, IndexError) as e:
     60         print filename
     61         print img.shape, img.dtype
     62         print e
     63       input_queue.task_done()
     64       output_queue.put((idx, img))
     65   
     66   # Write workers write resized images to the hdf5 file
     67   def write_worker():
     68     num_written = 0
     69     while True:
     70       idx, img = output_queue.get()
     71       if img.ndim == 3:
     72         # RGB image, transpose from H x W x C to C x H x W
     73         imgs_dset[idx] = img.transpose(2, 0, 1)
     74       elif img.ndim == 2:
     75         # Grayscale image; it is H x W so broadcasting to C x H x W will just copy
     76         # grayscale values into all channels.
     77         imgs_dset[idx] = img
     78       output_queue.task_done()
     79       num_written = num_written + 1
     80       if num_written % 100 == 0:
     81         print 'Copied %d / %d images' % (num_written, num_images)
     82   
     83   # Start the read workers.
     84   for i in xrange(args.num_workers):
     85     t = Thread(target=read_worker)
     86     t.daemon = True
     87     t.start()
     88     
     89   # h5py locks internally, so we can only use a single write worker =(
     90   t = Thread(target=write_worker)
     91   t.daemon = True
     92   t.start()
     93     
     94   for idx, filename in enumerate(image_list):
     95     if args.max_images > 0 and idx >= args.max_images: break
     96     input_queue.put((idx, filename))
     97     
     98   input_queue.join()
     99   output_queue.join()
    100   
    101   
    102   
    103 if __name__ == '__main__':
    104   
    105   with h5py.File(args.output_file, 'w') as f:
    106     add_data(f, args.train_dir, 'train2014', args)
    107 
    108     if args.include_val != 0:
    109       add_data(f, args.val_dir, 'val2014', args)
  • 相关阅读:
    js内置数据类型
    vue禁止复制的方式
    阻止element组件中的<el-input/>的粘贴功能
    Vue插件集合
    qs.parse()、qs.stringify()、JSON.stringify() 用法及区别
    es6数组的一些函数方法使用
    文章段落首字母缩进两个字符
    深圳scala-meetup-20180902(3)- Using heterogeneous Monads in for-comprehension with Monad Transformer
    深圳scala-meetup-20180902(2)- Future vs Task and ReaderMonad依赖注入
    深圳scala-meetup-20180902(1)- Monadic 编程风格
  • 原文地址:https://www.cnblogs.com/wangxiaocvpr/p/6322318.html
Copyright © 2011-2022 走看看