zoukankan      html  css  js  c++  java
  • Tensorflow创建和读取17flowers数据集

    http://blog.csdn.net/sinat_16823063/article/details/53946549

    Tensorflow创建和读取17flowers数据集

    标签: tensorflow
     分类:
        近期开始学习tensorflow,看了很多视频教程以及博客,大多数前辈在介绍tensorflow的用法时都会调用官方文档里给出的数据集,但是对于我这样的小白来说,如果想训练自己的数据集,自己将图片转换成可以输入到网络中的格式确实是有难度。但如果不会做图片的预处理,迈不出这一步,今后的学习之路会越来越难走,所以今天还是硬着头皮把我这几天已经实现的部分做一个总结。主要参考了一篇博客,文章最后有链接,通过这位博主的方法我成功生成了自己的数据集。
        首先,介绍一下用到的两个库,一个是os,一个是PIL。PIL(Python Imaging Library)是 Python 中最常用的图像处理库,而Image类又是 PIL库中一个非常重要的类,通过这个类来创建实例可以有直接载入图像文件,读取处理过的图像和通过抓取的方法得到的图像这三种方法。
        我采用的数据集是17 Category Flower Dataset。17flowers是牛津大学Visual Geometry Group选取的在英国比较常见的17种花。其中每种花有80张图片,整个数据及有1360张图片,可以在官网下载。不过在后续的训练过程中遇到了过拟合的问题,稍后再解释。
        由于17-flower数据集的结构如下图所示,标签就是最外层的文件夹的名字。所以在输入标签的时候可以直接通过文件读取的方式。
     
        我们是通过TFRecords来创建数据集的,TFRecords其实是一种二进制文件,虽然它不如其他格式好理解,但是它能更好的利用内存,更方便复制和移动,并且不需要单独的标签文件(label)。
    [python] view plain copy
     
    1. import os  
    2. import tensorflow as tf  
    3. from PIL import Image  
    4.   
    5. cwd = os.getcwd()  
    6. classes = os.listdir(cwd+"/17flowers/jpg")  
    7.   
    8. writer = tf.python_io.TFRecordWriter("train.tfrecords")  
    9. for index, name in enumerate(classes):  
    10.     class_path = cwd + "/17flowers/jpg/" + name + "/"  
    11.     if os.path.isdir(class_path):  
    12.         for img_name in os.listdir(class_path):  
    13.             img_path = class_path + img_name  
    14.             img = Image.open(img_path)  
    15.             img = img.resize((224, 224))  
    16.             img_raw = img.tobytes()              #将图片转化为原生bytes  
    17.             example = tf.train.Example(features=tf.train.Features(feature={  
    18.             "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[int(name)])),  
    19.             'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))  
    20.         }))  
    21.             writer.write(example.SerializeToString())  #序列化为字符串  
    22.             writer.close()  
    23.             print(img_name)  

        我们使用tf.train.Example来定义我们要填入的数据格式,其中label即为标签,也就是最外层的文件夹名字,img_raw为易经理二进制化的图片。然后使用tf.python_io.TFRecordWriter来写入。基本的,一个Example中包含Features,Features里包含Feature(这里没s)的字典。最后,Feature里包含有一个 FloatList, 或者ByteList,或者Int64List。就这样,我们把相关的信息都存到了一个文件中,所以前面才说不用单独的label文件。而且读取也很方便。

        下面测试一下,已经存好的训练集是否可用:

    [python] view plain copy
     
    1. for serialized_example in tf.python_io.tf_record_iterator("train.tfrecords"):  
    2.     example = tf.train.Example()  
    3.     example.ParseFromString(serialized_example)  
    4.   
    5.     image = example.features.feature['image'].bytes_list.value  
    6.     label = example.features.feature['label'].int64_list.value  
    7.     # 可以做一些预处理之类的  
    8.     print image, label  

        可以输出值,那么现在我们创建好的数据集已经存储在了统计目录下的train.tfrecords中了。接下来任务就是通过队列(queue)来读取这个训练集中的数据。

    [python] view plain copy
     
    1. def read_and_decode(filename):
      
    2.   #根据文件名生成一个队列
      
    3.   filename_queue = tf.train.string_input_producer([filename])

      
    4.   reader = tf.TFRecordReader()
      
    5.   _, serialized_example = reader.read(filename_queue)     
    6.   #返回文件名和文件
      
    7.   features = tf.parse_single_example(serialized_example,
features={
       
    8.                                                'label': tf.FixedLenFeature([], tf.int64),
                                                                    'img_raw' : tf.FixedLenFeature([], tf.string),
})

      
    9.   img = tf.decode_raw(features['img_raw'], tf.uint8)
      
    10.   img = tf.reshape(img, [224, 224, 3])
      
    11.   img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
      
    12.   label = tf.cast(features['label'], tf.int64)

      
    13.   return img, label  
    其中的filename,即刚刚通过TFReader来生成的训练集。通过将其转化成string类型数据,再通过reader来读取队列中的文件,并通过features的名字,‘label’和‘img_raw’来得到对应的标签和图片数据。之后就是一系列的转码和reshape的工作了。
        准备好了这些训练集,接下来就是利用得到的label和img进行网络的训练了。
    [python] view plain copy
     
    1. img, label = read_and_decode("train.tfrecords")
  
    2. img_batch, label_batch = tf.train.shuffle_batch([img, label],batch_size=100, capacity=2000,
 min_after_dequeue=1000)
  
    3. labels = tf.one_hot(label_batch,17,1,0)  
    4.  
coord = tf.train.Coordinator()
  
    5.  threads = tf.train.start_queue_runners(coord=coord,sess=sess)  
    6.  

for i in range(200):
      
    7.    batch_xs, batch_ys = sess.run([img_batch, labels])
      
    8.    print(sess.run(train_step, feed_dict={xs: batch_xs, ys: batch_ys, keep_prob: 0.5}))
      
    9.    print("Loss:", sess.run(cross_entropy,feed_dict={xs: batch_xs, ys: batch_ys, keep_prob: 0.5}))
       
    10.    if i % 50 == 0:
           
    11.      print(compute_accuracy(mnist.test.images, mnist.test.labels))
  
    12.   
    13.  coord.request_stop()  
    14.  
coord.join()  

    注意一点,由于这里使用了队列的方式来进行训练集的读取,所以异步方式,通过Coordinator让queue runner通过coordinator来启动这些线程,并在最后读取队列结束后终止线程。
        不过,在训练这个训练集的过程中不断的输出loss函数值,发现只迭代了5次就为0了,目前想到的原因可能是训练集太小,每个类只有80张图片。另一个原因可能是网络结构太深,由于使用了VGGNet,训练参数太多,容易过拟合。下次做个小规模的网络测试一下。
    delphi lazarus opengl 网页操作自动化, 图像分析破解,游戏开发
  • 相关阅读:
    drop table 、delete table和truncate table的区别
    润乾报表 删除导出excel弹出框里的选项
    学习笔记: 委托解析和封装,事件及应用
    学习笔记: 特性Attribute详解,应用封装
    学习笔记: 反射应用、原理,完成扩展,emit动态代码
    学习笔记: 泛型应用、原理、协变逆变、泛型缓存
    jmeter4.x centos7部署笔记
    rabbitmq3.7.5 centos7 集群部署笔记
    rabbitmq3.8.3 centos7 安装笔记
    UVA-12436 Rip Van Winkle's Code (线段树区间更新)
  • 原文地址:https://www.cnblogs.com/delphi-xe5/p/7001054.html
Copyright © 2011-2022 走看看