zoukankan      html  css  js  c++  java
  • 使用tfrecord建立自己的数据集

    注意事项:

    1.关于输入图像格式的问题

        使用io.imread()的时,根据输入图像确定as_grey的参数值。 转化为字符串之后(image.tostring) ,最后输出看下image_raw的长度。因为不同的图像编码格式,存储方式不同。

       我读入的灰度图jpeg格式,类型是int64,image_raw的大小是图像的大小的8倍 。 但如果是RGB图像,则统一类型是uint8。确定了类型,在之后的解码 (decode_raw)中,需要将type设置和存储方式同样的类型。 

       根据image_raw的长度和原图像大小,推算一下使用的类型,常用的是uint8,int32,int64.

    2.转化成tfrecords的时间有点长,需要等待。

    import os
    import tensorflow as tf
    import numpy as np
    import skimage.io as io
    import matplotlib.pyplot as plt
    import cv2
    def get_data (file_path):
        data = []
        label = []
        for dirs in os.listdir(file_path):
            temp_path = os.path.join(file_path,dirs)
            i =0
            for dirss in os.listdir(temp_path):
                data.append(os.path.join(temp_path,dirss))
            num_img = len(os.listdir(temp_path))
            label = np.append(label,num_img*[1])
        temp = np.array([data,label])
        temp = temp.transpose()
        np.random.shuffle(temp)
        image_list = list(temp[:,0])
        label_list = list(temp[:,1])
        label_list = [int(float(i)) for i in label_list]
        return image_list,label_list
    # 转化成字符串
    def _int64_feature(value):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
    def _bytes_feature(value):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
    def convert_tfrecord(images,labels,save_filename):
        writer = tf.python_io.TFRecordWriter(save_filename)
        print("Transform start....")
        num_examples= len(labels)
        if np.shape(images)[0]!=num_examples:
            raise ValueError('Images size %d does not match label size %d.' % (images.shape[0], num_examples))
        for index in np.arange(0,num_examples):
            try:
                image = io.imread(images[index],as_grey=False)
                #image = tf.image.decode_jpeg(images[index])
                #print(image.shape)
                image_raw = image.tostring()
                #print(len(image_raw))
                example = tf.train.Example(features = tf.train.Features(feature={
                    'label' :_int64_feature(int(labels[index])),
                    'image_raw':_bytes_feature(image_raw)
                }))
                writer.write(example.SerializeToString())
            except IOError as e:
                print('Could not read:',images[index])
                print('error :%s Skip it !
    '%e)
        writer.close()
        print("success!")
    
    def read_and_decode(tfrecords_file,batch_size):
        reader = tf.TFRecordReader()
        filename_queue = tf.train.string_input_producer([tfrecords_file])
        _,serialized_example = reader.read(filename_queue)
        features = tf.parse_single_example(
            serialized_example,
            features={
                'label': tf.FixedLenFeature([],tf.int64),
                'image_raw': tf.FixedLenFeature([], tf.string)
            }
        )
        #print(features['image_raw'])
        capacity = 1000+3*batch_size
        image = tf.decode_raw(features['image_raw'],tf.uint8)
        label = tf.cast(features['label'],tf.int32)
        #image = tf.image.resize_images(image,[300, 200, 1])
        image = tf.reshape(image,[200,300,3])
        image_batch,label_batch = tf.train.batch([image,label],
                                                 batch_size=batch_size,
                                                 capacity=capacity)
        image_batch = tf.image.resize_image_with_crop_or_pad(image_batch,100,100)
        image_batch = tf.cast(image_batch, tf.float32) * (1. / 255)
        return image_batch,label_batch
    def plot_images(images, labels):
        '''plot one batch size
        '''
        for i in np.arange(0, 2):
            plt.subplot(3, 3, i + 1)
            plt.axis('off')
            # plt.title((labels[i] - 1), fontsize = 14)
            plt.subplots_adjust(top=1)
            print(labels[i])
            print(images.shape)
            # print(images[i].shape)
            plt.imshow(images[i][:,:,:])
        plt.show()
    def train():
        image,label = get_data('E:syn_data')
        convert_tfrecord(image,label,'1.tfrecords')
        x_batch, y_batch = read_and_decode('1.tfrecords', batch_size=2)
        with tf.Session() as sess:
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            try:
                i=0
                while not coord.should_stop() and i<3:
                         # just plot one batch size
                    image, label = sess.run([x_batch, y_batch])
                    plot_images(image, label)
                    i+=1
            except tf.errors.OutOfRangeError:
                print('done!')
            finally:
                coord.request_stop()
            coord.join(threads)
    
    #train()
  • 相关阅读:
    2017-2018-1 20155326 《信息安全系统设计基础》第六周课上作业
    20155326 2017-2018-1 《信息安全系统设计基础》缓冲区溢出漏洞实验
    2017-2018-1 201552326《信息安全技术》实验二——Windows口令破解
    《科技之巅2》序——机器智能数据智能:工具之王
    云大使成长精华指引(全)
    程序员职业规划课:如何开启"第二春"?
    明明可以靠脸吃饭偏要靠才华_你身边有女神程序员吗?
    6月19日云栖精选夜读:血泪总结!创业公司CTO要避免哪些坑?
    玩过这些经典单机游戏_就说明你已经老了
    帮程序员减压放松的10个良心网站
  • 原文地址:https://www.cnblogs.com/jzcbest1016/p/8059197.html
Copyright © 2011-2022 走看看