zoukankan      html  css  js  c++  java
  • 制作数据集-解析篇

    1、数据集生成读取文件(mnist_generateds.py)
    tfrecords 文件
    1)tfrecords:是一种二进制文件,可先将图片和标签制作成该格式的文件。使用 tfrecords 进行数据读取,会提高内存利用率。 
    2)tf.train.Example: 用来存储训练数据。训练数据的特征用键值对的形式表示。如:‘ img_raw ’ :值 ‘ label ’ :值 值是 Byteslist/FloatList/Int64List 
    3)SerializeToString( ):把数据序列化成字符串存储。
    首先生成 tfrecords 文件 :

    1)将数据集的相关路径定义好

    2)读训练集和测试集;

    读文件解析

    a:先读入文件名,路径

    b:新建一个writer,计数次数

    c: 在open函数中默认为只读形式打开label_path,readlines() 方法用于读取所有行(直到结束符 EOF)并返回列表,该列表可以由 Python 的 for... in ... 结构进行处理。如果碰到结束符 EOF 则返回空字符串。

    d:for 循环遍历每张图和标签

    f:

      example = tf.train.Example(features=tf.train.Features(feature={
                    'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
                    'label': tf.train.Feature(int64_list=tf.train.Int64List(value=labels))
                    })) 
    这段代码将在tf.train.examle函数中讲解
    e: writer.write(example.SerializeToString())   # 把 example 进行序列化 
    image_train_path='./data/mnist_data_jpg/mnist_train_jpg_60000/'
    label_train_path='./data/mnist_data_jpg/mnist_train_jpg_60000.txt'
    tfRecord_train='./data/mnist_train.tfrecords'
    image_test_path='./data/mnist_data_jpg/mnist_test_jpg_10000/'
    label_test_path='./data/mnist_data_jpg/mnist_test_jpg_10000.txt'
    tfRecord_test='./data/mnist_test.tfrecords'
    data_path='./data'
    resize_height = 28
    resize_width = 28
    
    #生成tfrecords文件
    def write_tfRecord(tfRecordName, image_path, label_path):
        #新建一个writer
        writer = tf.python_io.TFRecordWriter(tfRecordName)  
        num_pic = 0 
        f = open(label_path, 'r')
        contents = f.readlines()
        f.close()
        #循环遍历每张图和标签 
        for content in contents:
            value = content.split()
            img_path = image_path + value[0] 
            img = Image.open(img_path)
            img_raw = img.tobytes() 
            labels = [0] * 10  
            labels[int(value[1])] = 1  
            #把每张图片和标签封装到example中    
            example = tf.train.Example(features=tf.train.Features(feature={
                    'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
                    'label': tf.train.Feature(int64_list=tf.train.Int64List(value=labels))
                    })) 
            #把example进行序列化
            writer.write(example.SerializeToString())
            num_pic += 1 
            print ("the number of picture:", num_pic)
        #关闭writer
        writer.close()
        print("write tfrecord successful")
    
    def generate_tfRecord():
        isExists = os.path.exists(data_path) 
        if not isExists: 
             os.makedirs(data_path)
            print 'The directory was created successfully'
        else:
            print 'directory already exists' 
        write_tfRecord(tfRecord_train, image_train_path, label_train_path)
         write_tfRecord(tfRecord_test, image_test_path, label_test_path)

    解析tfrecords文件:

    先看一下对应的路径文件名:

    image_train_path='./data/mnist_data_jpg/mnist_train_jpg_60000/'
    label_train_path='./data/mnist_data_jpg/mnist_train_jpg_60000.txt'
    tfRecord_train='./data/mnist_train.tfrecords'
    image_test_path='./data/mnist_data_jpg/mnist_test_jpg_10000/'
    label_test_path='./data/mnist_data_jpg/mnist_test_jpg_10000.txt'
    tfRecord_test='./data/mnist_test.tfrecords'
    data_path='./data'
    def main():
        generate_tfRecord()                                  (1)

    def generate_tfRecord():                                 (2)
        isExists = os.path.exists(data_path) 
        if not isExists: 
             os.makedirs(data_path)
            print 'The directory was created successfully'
        else:
            print 'directory already exists'  
    def get_tfrecord(num, isTrain=True):                                                   (1)                                      
    if isTrain:
    tfRecord_path = tfRecord_train

    else:
    tfRecord_path = tfRecord_test

    img, label = read_tfRecord(tfRecord_path) (2)
    def read_tfRecord(tfRecord_path):                                                       (3)
    filename_queue = tf.train.string_input_producer([tfRecord_path], shuffle=True)
        #新建一个reader
     reader = tf.TFRecordReader()

    上面用颜色标记了一些参数的传递情况;
    下面从主函数谈起:
    main函数调用了generate_tfRecord()函数,在该函数中使用write_tfRecord()写入相关数据集到tfRecord中;
           get_tfrecord函数获取被写入的数据集的tfRecord的地址tfRecord_path = tfRecord_train (tfRecord_path = tfRecord_test ),使用read_tfRecord()函数读取信息。
    在read_tfRecord()函数中;

              1)filename_queue = tf.train.string_input_producer([tfRecord_path])
                                     tf.train.string_input_producer( string_tensor,
                                                                                     num_epochs=None,
                                                                                     shuffle=True,
                                                                                     seed=None,
                                                                                    capacity=32,
                                                                                    shared_name=None,
                                                                                    name=None,
                                                                                    cancel_op=None)
           该函数会生成一个先入先出的队列,文件阅读器会使用它来读取数据。
           参数说明:string_tensor: 存储图像和标签信息的 TFRecord 文件名列表
                             num_epochs: 循环读取的轮数(可选)
                             shuffle:布尔值(可选),如果为 True,则在每轮随机打乱读取顺序
                             seed:随机读取时设置的种子(可选)
                             capacity:设置队列容量
                             shared_name:(可选) 如果设置,该队列将在多个会话中以给定名称共享。所有具有此队列的设备都可以通过 shared_name 访问它。在分布式设置中使用这种方法意味着每

    个名称只能被访问此操作的其中一个会话看到。 
                            name:操作的名称(可选)
                            cancel_op:取消队列(None)
          2)reader = tf.TFRecordReader() #新建一个 reader
          3)_, serialized_example = reader.read(filename_queue)
              features = tf.parse_single_example(serialized_example,features={
                                                                                              'img_raw': tf.FixedLenFeature([ ], tf.string) ,
                                                                                              'label': tf.FixedLenFeature([10], tf.int64)})
                              #把读出的每个样本保存在 serialized_example 中进行解序列化,标签和图片的键名应该和制作 tfrecords 的键名相同,其中标签给出几分类。 
             tf.parse_single_example(serialized,
                                                     features,
                                                     name=None,
                                                     example_names=None)
               该函数可以将 tf.train.Example 协议内存块(protocol buffer)解析为张量。
                      参数说明:serialized: 一个标量字符串张量
                                        features: 一个字典映射功能键 FixedLenFeature 或 VarLenFeature值,也就是在协议内存块中储存的 
                                        name:操作的名称(可选)
                                        example_names: 标量字符串联的名称(可选)
            4)img = tf.decode_raw(features['img_raw'], tf.uint8)
                             #将 img_raw 字符串转换为 8 位无符号整型
           5)img.set_shape([784]) #将形状变为一行 784 列
           6)img = tf.cast(img, tf.float32) * (1. / 255) #变成 0 到 1 之间的浮点数
           7)label = tf.cast(features['label'], tf.float32)#把标签列表变为浮点数
           8)return image,label #返回图片和标签(跳回到 get_tfrecord)
           9) tf.train.shuffle_batch( tensors,
                                                  batch_size,
                                                  capacity,
                                                  min_after_dequeue,
                                                  num_threads=1,
                                                  seed=None,
                                                  enqueue_many=False,
                                                  shapes=None,
                                                  allow_smaller_final_batch=False,
                                                  shared_name=None,
                                                 name=None)
        这个函数随机读取一个 batch 的数据。
                     参数说明:tensors: 待乱序处理的列表中的样本(图像和标签)
                                        batch_size: 从队列中提取的新批量大小
                                       capacity:队列中元素的最大数量
                                       min_after_dequeue: 出队后队列中的最小数量元素,用于确保元素的混合级别
                                       num_threads: 排列 tensors 的线程数
                                       seed:用于队列内的随机洗牌
                                      enqueue_many: tensor 中的每个张量是否是一个例子
                                      shapes: 每个示例的形状
                                      allow_smaller_final_batch: (可选)布尔值。 如果为 True,则在队列中剩余数量不足时允许最终批次更小。 
                                      shared_name:(可选)如果设置,该队列将在多个会话中以给定名称共享。 
                                      name:操作的名称(可选)
              10)return img_batch,label_batch             #返回的图片和标签为随机抽取的 batch_size 组
    2.反向传播文件修改图片标签获取的接口(mnist_backward.py) 

    关键操作:利用多线程提高图片和标签的批获取效率       方法:将批获取的操作放到线程协调器开启和关闭之间 

    开启线程协调器:
         coord = tf.train.Coordinator( )
         threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    关闭线程协调器:
         coord.request_stop( )
         coord.join(threads)
    注解:
         tf.train.start_queue_runners( sess=None,
         coord=None,
        daemon=True,
        start=True,
        collection=tf.GraphKeys.QUEUE_RUNNERS)
    这个函数将会启动输入队列的线程,填充训练样本到队列中,以便出队操作可以从队列中拿到样本。这种情况下最好配合使用一个 tf.train.Coordinator ,这

    样可以在发生错误的情况下正确地关闭这些线程。

    参数说明:sess:用于运行队列操作的会话。 默认为默认会话。
                      coord:可选协调器,用于协调启动的线程。
                       daemon: 守护进程,线程是否应该标记为守护进程,这意味着它们不会阻止程序退出。 
                       start:设置为 False 只创建线程,不启动它们。
                      collection:指定图集合以获取启动队列的 GraphKey。默认为GraphKeys.QUEUE_RUNNERS。

    与之前学习代码的区别

    1)TEST_NUM=10000
    之前:用 mnist.test.num_examples 表示总样本数;
    现在:要手动给出测试的总样本数,这个数是 1 万。
    2)image_batch, label_batch=mnist_generateds.get_tfrecord(TEST_NUM, isTrain=False) 
    之前:用 mnist.test.next_batch 函数读出图片和标签喂给网络;
    现在:用函数 get_tfrecord 替换读取所有测试集 1 万张图片。
    isTrain:用来区分训练阶段和测试阶段,True 表示训练,False 表示测试。
    3)xs,ys=sess.run([img_batch,label_batch])
    之前:使用函数 xs,ys=mnist.test.next_batch(BATCH_SIZE)
    现在:在 sess.run 中执行图片和标签的批获取。



  • 相关阅读:
    html-webpack-plugin & clean-webpack-plugin
    Using webpack-dev-server
    Using Watch Mode
    webpack中devtool的配置方案[开发模式]---[线上模式]
    linux命令系列-mv(移动-重命名)
    洗牌函数[打乱数组的顺序] slice()的新运用 [原来arr.slice(start, end) 的start不是必需的]
    Currency Exchange (POJ1860)(判断正圈)(spfa) 最短路专题
    PTA L3-020 至多删三个字符 (DP) (天梯赛训练)
    Codeforces Round #658 (Div. 2)(D. Unmerge)
    Codeforces Round #656 (Div. 3) (E. Directing Edges)(拓扑排序)
  • 原文地址:https://www.cnblogs.com/fcfc940503/p/11019441.html
Copyright © 2011-2022 走看看