zoukankan      html  css  js  c++  java
  • 将数据转为tfrecord格式

    假设emo文件夹下,有1,2,3,4等文件夹,每个文件夹代表一个类别

      1 import tensorflow as tf
      2 from PIL import Image
      3 from glob import glob
      4 import os
      5 import progressbar
      6 import time
      7 
      8 
      class TFRecord():
          """Write a folder of JPEG images to a TFRecord file and read it back.

          The source folder is expected to contain one sub-folder per class;
          every ``*.jpg`` file inside a sub-folder is one sample of that class.

          Args:
              path: Root folder holding the per-class sub-folders.
              tfrecord_file: Path of the TFRecord file to write and/or read.
          """

          def __init__(self, path=None, tfrecord_file=None):
              self.path = path
              self.tfrecord_file = tfrecord_file

          def _convert_image(self, idx, img_path, is_training=True):
              """Build a ``tf.train.Example`` for a single image file.

              Args:
                  idx: Integer class label of the image.
                  img_path: Path of the image file; its raw bytes are stored
                      unmodified (no decoding or resizing happens here).
                  is_training: When False the stored label is ``-1`` instead
                      of ``idx``.

              Returns:
                  A ``tf.train.Example`` with the features ``file_name``
                  (bytes), ``img`` (bytes) and ``label`` (int64).
              """
              with tf.gfile.FastGFile(img_path, 'rb') as fid:
                  img_str = fid.read()

              # Samples without a known class are marked with -1.
              label = idx if is_training else -1

              feature_key_value_pair = {
                  'file_name': tf.train.Feature(bytes_list=tf.train.BytesList(
                      value=[img_path.encode()])),
                  'img': tf.train.Feature(bytes_list=tf.train.BytesList(
                      value=[img_str])),
                  'label': tf.train.Feature(int64_list=tf.train.Int64List(
                      value=[label])),
              }
              features = tf.train.Features(feature=feature_key_value_pair)
              return tf.train.Example(features=features)

          def convert_img_folder(self):
              """Write every ``*.jpg`` below ``self.path`` to ``self.tfrecord_file``.

              Each sub-folder of ``self.path`` is one class. Sub-folders are
              sorted by name (lexicographically) and the position in that
              order is used as the label, so labels are reproducible between
              runs. Entries of ``self.path`` that are not directories are
              ignored.
              """
              folder_path = self.path

              # os.listdir() returns entries in arbitrary order; sort so that
              # the folder -> label mapping does not depend on the filesystem.
              class_dirs = sorted(
                  os.path.join(folder_path, name)
                  for name in os.listdir(folder_path)
                  if os.path.isdir(os.path.join(folder_path, name))
              )

              # Collect exactly the files that will be written so that the
              # progress bar total matches the number of records.
              samples = [
                  (label, img_path)
                  for label, class_dir in enumerate(class_dirs)
                  for img_path in sorted(glob(os.path.join(class_dir, '*.jpg')))
              ]

              with tf.python_io.TFRecordWriter(self.tfrecord_file) as tfwrite:
                  if not samples:
                      # Nothing to write; skip the progress bar because a
                      # total of 0 may not be handled by every progressbar
                      # release (NOTE(review): not verified here).
                      return

                  widgets = ['[INFO] write image to tfrecord: ',
                             progressbar.Percentage(), " ",
                             progressbar.Bar(), " ", progressbar.ETA()]
                  pbar = progressbar.ProgressBar(
                      maxval=len(samples), widgets=widgets).start()

                  for written, (label, img_path) in enumerate(samples, start=1):
                      example = self._convert_image(label, img_path)
                      tfwrite.write(example.SerializeToString())
                      # Report the number of records written so far, so the
                      # bar reaches 100% after the last image.
                      pbar.update(written)

                  pbar.finish()

          def _extract_fn(self, tfrecord):
              """Parse one serialized example into image, label and file name.

              Args:
                  tfrecord: A scalar string tensor holding one serialized
                      ``tf.train.Example`` as written by ``_convert_image``.

              Returns:
                  ``[img, label, file_name]`` where ``img`` is the decoded
                  JPEG resized to 224x224.
              """
              features = {
                  'file_name': tf.FixedLenFeature([], tf.string),
                  'img': tf.FixedLenFeature([], tf.string),
                  'label': tf.FixedLenFeature([], tf.int64)
              }
              sample = tf.parse_single_example(tfrecord, features)
              # NOTE(review): no ``channels`` argument is given, so the
              # channel count follows each file; mixing grayscale and RGB
              # images may break batching - confirm against the data set.
              img = tf.image.decode_jpeg(sample['img'])
              # method=1 selects nearest-neighbour resizing in TF 1.x.
              img = tf.image.resize_images(img, (224, 224), method=1)
              return [img, sample['label'], sample['file_name']]

          def extract_image(self, shuffle_size, batch_size):
              """Build a shuffled, batched dataset from ``self.tfrecord_file``.

              Args:
                  shuffle_size: Buffer size passed to ``Dataset.shuffle``.
                  batch_size: Number of samples per batch.

              Returns:
                  A ``tf.data`` dataset whose elements are
                  ``(img, label, file_name)`` batches.
              """
              dataset = tf.data.TFRecordDataset([self.tfrecord_file])
              dataset = dataset.map(self._extract_fn)
              dataset = dataset.shuffle(shuffle_size).batch(batch_size)
              print("---------", type(dataset))
              return dataset
     96 
     97 
     98 
     99 
    if __name__ == '__main__':
        # Looping over a tf.data dataset with a plain Python ``for`` requires
        # eager execution in TF 1.x; enable it before any other TF call.
        tf.enable_eager_execution()

        t = TFRecord('/media/xia/Data/emo', '/media/xia/Data/emo.tfrecord')
        t.convert_img_folder()
        dataset = t.extract_image(100, 64)
        # Each element is one batch: (images, labels, file names).
        for data, label, _ in dataset:
            print(label)
            print(data.shape)
    View Code

    ps: 上面的代码用 for 循环直接遍历 dataset,因此需要在程序开头先调用 tf.enable_eager_execution() 启用 eager 模式。

         测试环境:tf.__version__ == '1.8.0'

    参考:https://zhuanlan.zhihu.com/p/30751039

              https://lonepatient.top/2018/06/01/tensorflow_tfrecord.html

             https://zhuanlan.zhihu.com/p/51186668

  • 相关阅读:
    c语言博客作业04--数组
    C博客作业03--函数
    c博客作业02--循环结构
    C博客作业01--顺序分支结构
    我的第一篇博客
    java--购物车程序的面向对象设计
    c博客作业05--指针
    C博客作业04--数组
    C博客作业03--函数
    C博客作业02--循环结构
  • 原文地址:https://www.cnblogs.com/573177885qq/p/11186023.html
Copyright © 2011-2022 走看看