zoukankan      html  css  js  c++  java
  • 读取 TFRecord 的代码——任意照片均可使用 2

    代码

      1 # -*- coding: utf-8 -*-
      2 # @Time    : 2018/12/1 11:06
      3 # @Author  : MaochengHu
      4 # @Email   : wojiaohumaocheng@gmail.com
      5 # @File    : read_tfrecord.py
      6 # @Software: PyCharm
      7 import os
      8 import tensorflow as tf
      9 flags = tf.app.flags
     10 flags.DEFINE_string('tfrecord_path', '/data1/humaoc_file/classify/data/train_tfrecord/train.record',
     11                     'path to tfrecord file')
     12 flags.DEFINE_integer('resize_height', 800, 'resize height of image')
     13 flags.DEFINE_integer('resize_width', 800, 'resize width of image')
     14 FLAG = flags.FLAGS
     15 slim = tf.contrib.slim
     16 
     17 def print_data(image, resized_image, label, height, width):
     18     with tf.Session() as sess:
     19         init_op = tf.global_variables_initializer()
     20         sess.run(init_op)
     21         coord = tf.train.Coordinator()
     22         threads = tf.train.start_queue_runners(coord=coord)
     23         for i in range(20):
     24             print("______________________image({})___________________".format(i))
     25             print_image, print_resized_image, print_label, print_height, print_width = sess.run(
     26                 [image, resized_image, label, height, width])
     27             print("resized_image shape is: ", print_resized_image.shape)
     28             print("image shape is: ", print_image.shape)
     29             print("image label is: ", print_label)
     30             print("image height is: ", print_height)
     31             print("image width is: ", print_width)
     32         coord.request_stop()
     33         coord.join(threads)
     34 
     35 def reshape_same_size(image, output_height, output_width):
     36     """Resize images by fixed sides.
     37 
     38     Args:
     39         image: A 3-D image `Tensor`.
     40         output_height: The height of the image after preprocessing.
     41         output_ The width of the image after preprocessing.
     42 
     43     Returns:
     44         resized_image: A 3-D tensor containing the resized image.
     45     """
     46     output_height = tf.convert_to_tensor(output_height, dtype=tf.int32)
     47     output_width = tf.convert_to_tensor(output_width, dtype=tf.int32)
     48 
     49     image = tf.expand_dims(image, 0)
     50     resized_image = tf.image.resize_nearest_neighbor(
     51         image, [output_height, output_width], align_corners=False)
     52     resized_image = tf.squeeze(resized_image)
     53     return resized_image
     54 
     55 def read_tfrecord(tfrecord_path, num_samples=14635, num_classes=7, resize_height=800, resize_width=800):
     56     keys_to_features = {
     57         'image/encoded': tf.FixedLenFeature([], default_value='', dtype=tf.string, ),
     58         'image/format': tf.FixedLenFeature([], default_value='jpeg', dtype=tf.string),
     59         'image/class/label': tf.FixedLenFeature([], tf.int64, default_value=0),
     60         'image/height': tf.FixedLenFeature([], tf.int64, default_value=0),
     61         'image/width': tf.FixedLenFeature([], tf.int64, default_value=0)
     62     }
     63 
     64     items_to_handlers = {
     65         'image': slim.tfexample_decoder.Image(image_key='image/encoded', format_key='image/format', channels=3),
     66         'label': slim.tfexample_decoder.Tensor('image/class/label', shape=[]),
     67         'height': slim.tfexample_decoder.Tensor('image/height', shape=[]),
     68         'width': slim.tfexample_decoder.Tensor('image/width', shape=[])
     69     }
     70     decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)
     71 
     72     labels_to_names = None
     73     items_to_descriptions = {
     74         'image': 'An image with shape image_shape.',
     75         'label': 'A single integer between 0 and 9.'}
     76 
     77     dataset = slim.dataset.Dataset(
     78         data_sources=tfrecord_path,
     79         reader=tf.TFRecordReader,
     80         decoder=decoder,
     81         num_samples=num_samples,
     82         items_to_descriptions=None,
     83         num_classes=num_classes,
     84     )
     85     provider = slim.dataset_data_provider.DatasetDataProvider(dataset=dataset,
     86                                                               num_readers=3,
     87                                                               shuffle=True,
     88                                                               common_queue_capacity=256,
     89                                                               common_queue_min=128,
     90                                                               seed=None)
     91     image, label, height, width = provider.get(['image', 'label', 'height', 'width'])
     92     resized_image = tf.squeeze(tf.image.resize_bilinear([image], size=[resize_height, resize_width]))
     93     return resized_image, label, image, height, width
     94 
     95 if __name__ == '__main__':
     96     resized_image, label, image, height, width = read_tfrecord(tfrecord_path='train.record',
     97                                                                resize_height=800,
     98                                                                resize_width=800)
     99     # resized_image = reshape_same_size(image, FLAG.resize_height, FLAG.resize_width)
    100     # resized_image = tf.squeeze(tf.image.resize_bilinear([image], size=[FLAG.resize_height, FLAG.resize_width]))
    101     print_data(image, resized_image, label, height, width)
    102 
    103     init_g = tf.global_variables_initializer()
    104     init_l = tf.local_variables_initializer()
    105     with tf.Session() as sess:
    106         sess.run(init_g)
    107         sess.run(init_l)
    108         tf.train.start_queue_runners(sess)
    109         print("SDDFA")
    110         trX = image.eval(session=sess)
    111         trY = label.eval(session=sess)
    112     print("AA")
    113     print(trX.shape)
  • 相关阅读:
    Winform中怎样去掉TextBox输入回车时的警告音
    SQL Server 2000 出现“不能执行查询，因为一些文件丢失或未注册”
    c# winform 创建文件,把值写入文件,读取文件里的值,修改文件的值,对文件的创建,写入,修改
    <metro>PlayToReceiver class
    <metro>PlayToReceiver
    <C#>怎样学好Winform
    <C#>怎样学好winform3
    <C#>怎样学好winform4
    <metro>Application Data
    <metro>UI
  • 原文地址:https://www.cnblogs.com/smartisn/p/12438866.html
Copyright © 2011-2022 走看看