zoukankan      html  css  js  c++  java
  • Tensorflow实战图像转换成tfrecords和读取

    1、准备数据

          首选将自己的图像数据分类分别放在不同的文件夹下,比如新建data文件夹,data文件夹下分别存放up和low文件夹,up和low文件夹下存放对应的图像数据。也可以把up和low文件夹换成0和1。根据自己数据类别,自己设定。如图所示

     

     

         以上三张图片注意看目录。这样数据就准备好了。

    2、将图像数据转换成tfrecords

          直接上代码,代码中比较重要的部分我都做了注释。
     1 import os
     2 import tensorflow as tf 
     3 from PIL import Image
     4 import matplotlib.pyplot as plt
     5 import numpy as np
     6  
     7 sess=tf.InteractiveSession()
     8 cwd = "D://software//tensorflow//data//"  #数据所在目录位置
     9 classes = {'up', 'low'} #预先自己定义的类别,根据自己的需要修改
    10 writer = tf.python_io.TFRecordWriter("train.tfrecords")  #train表示转成的tfrecords数据格式的名字
    11  
    12 for index, name in enumerate(classes):
    13     class_path = cwd + name + "/"
    14     for img_name in os.listdir(class_path):
    15         img_path = class_path + img_name
    16         img = Image.open(img_path)
    17         img = img.resize((300, 300))  #图像reshape大小设置,根据自己的需要修改
    18         img_raw = img.tobytes()              
    19         example = tf.train.Example(features=tf.train.Features(feature={
    20             "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
    21             'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
    22         }))
    23         writer.write(example.SerializeToString()) 
    24 writer.close()

    3、从tfrecords中读取数据

          直接上代码:

     1 #读取文件
     2 def read_and_decode(filename,batch_size):
     3     #根据文件名生成一个队列
     4     filename_queue = tf.train.string_input_producer([filename])
     5     reader = tf.TFRecordReader()
     6     _, serialized_example = reader.read(filename_queue)   #返回文件名和文件
     7     features = tf.parse_single_example(serialized_example,
     8                                        features={
     9                                            'label': tf.FixedLenFeature([], tf.int64),
    10                                            'img_raw' : tf.FixedLenFeature([], tf.string),
    11                                        })
    12  
    13     img = tf.decode_raw(features['img_raw'], tf.uint8)
    14     img = tf.reshape(img, [300, 300, 3])                #图像归一化大小
    15    # img = tf.cast(img, tf.float32) * (1. / 255) - 0.5   #图像减去均值处理,根据自己的需要决定要不要加上
    16     label = tf.cast(features['label'], tf.int32)        
    17  
    18     #特殊处理,去数据的batch,如果不要对数据做batch处理,也可以把下面这部分不放在函数里
    19  
    20     img_batch, label_batch = tf.train.shuffle_batch([img, label],
    21                                                     batch_size= batch_size,
    22                                                     num_threads=64,
    23                                                     capacity=200,
    24                                                     min_after_dequeue=150)
    25     return img_batch, tf.reshape(label_batch,[batch_size])

    需要注意的地方:

    img = tf.cast(img, tf.float32) * (1. / 255) - 0.5   #图像减去均值处理,根据自己的需要决定要不要加上
    1 #特殊处理,去数据的batch,如果不要对数据做batch处理,也可以把下面这部分不放在函数里
    2     img_batch, label_batch = tf.train.shuffle_batch([img, label],
    3                                                     batch_size= batch_size,
    4                                                     num_threads=64,
    5                                                     capacity=200,
    6                                                     min_after_dequeue=150)

    如果不需要把数据做batch处理,则函数的第二个形参batch_size就去掉,函数直接返回img和label。也可以把batch处理部分放在函数外面,根据自己的需要自己修改一下。

    4、转换和读取函数的调用

    1 tfrecords_file = 'train.tfrecords'   #要读取的tfrecords文件
    2 BATCH_SIZE = 4      #batch_size的大小
    3 image_batch, label_batch = read_and_decode(tfrecords_file,BATCH_SIZE)  
    4 print(image_batch,label_batch)    #注意,这里不是tensor,tensor需要做see.run()处理   

       下面就定义session,执行即可,有一个地方需要注意,

    
    
    image_batch, label_batch = read_and_decode(tfrecords_file,BATCH_SIZE)   #需要注意

     虽然能够把数据读取出来,但是不是tensor,在训练的时候需要image,label=sess.run([image_batch,label_batch])处理后,才能投入训练。具体细节下一篇博客再做详细介绍。

     如果还有问题未能得到解决,搜索887934385交流群,进入后下载资料工具安装包等。最后,感谢观看!
  • 相关阅读:
    Jenkins系列——使用SonarQube进行代码质量检查
    HTTP1.0工作原理
    Jenkins系列——使用checkstyle进行代码规范检查
    Jenkins系列——定时构建
    Hadoop环境搭建
    eclipse3.4+对的处理插件(附SVN插件安装实例)
    MD5
    RedHat6.5更新软件源
    ubuntu软件推荐
    disconf系列【2】——解决zk部署情况为空的问题
  • 原文地址:https://www.cnblogs.com/pypypy/p/11829833.html
Copyright © 2011-2022 走看看