  • Crossing one off the list: making TFRecords files for multi-class images in the training, validation, and test sets

    Tomorrow I'll post a companion piece that reads the TFRecords back (a rough sketch of the reading side is included at the end of this post). I ran into trouble yesterday: my machine kept complaining that some library functions didn't exist, even though the libraries were already imported. That's just how Python is, so anyone who hasn't committed to it yet might as well go learn Caffe instead of being poisoned by Python. Anyway, here's the code. A few of the functions are unused; they come from an earlier post by a much more experienced blogger, who wrote a whole set of helpers just to check whether his records were built correctly. Impressive, as always.

    import tensorflow as tf
    import numpy as np
    import os
    import random
    from PIL import Image
    
    def _int64_feature(label):
        # wrap an integer label as an int64 Feature
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
    
    def _bytes_feature(value):
        # wrap raw bytes (here, the resized image pixels) as a bytes Feature
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
    
    def float_list_feature(value):
        # wrap a list of floats as a float Feature (unused below, kept from the original post)
        return tf.train.Feature(float_list=tf.train.FloatList(value=value))
    
    def get_example_nums(tf_records_filenames):
        # count how many serialized examples a finished .tfrecords file contains
        nums= 0
        for record in tf.python_io.tf_record_iterator(tf_records_filenames):
            nums += 1
        return nums
    
    def get_example_num(records_file_dir):
        # duplicate of get_example_nums via the tf.io alias (unused; availability depends on the TF version)
        nums=0
        for record in tf.io.tf_record_iterator(records_file_dir):
            nums+=1
        return nums
    
    def load_file(imagestxtdir,shuffle=False):
        # parse the list file: each line is "relative/image/path label"
        images=[]  # relative paths of the images in this split
        labels=[]
        with open(imagestxtdir) as f:
            lines_list=f.readlines()  # read every line of the list file
            if shuffle:
                random.shuffle(lines_list)  # shuffle the image order in place
            for line in lines_list:
                line_list=line.rstrip().split(' ')  # rstrip drops the trailing newline/whitespace
                label=[]
                for i in range(1):  # one label per image; widen the range for multi-label lines
                    label.append(int(line_list[i+1]))
                #cur_img_dir=images_base_dir+'/'+line_list[0]
                images.append(line_list[0])
                labels.append(label)
        return images,labels
    
    def get_batch_images(images,labels,batch_size,labels_num,one_hot=False,shuffle=False,num_threads=1):
        # batch decoded image/label tensors with the TF 1.x queue runners
        min_after_dequeue=200
        capacity=min_after_dequeue+3*batch_size
        if shuffle:
            images_batch,labels_batch=tf.train.shuffle_batch([images,labels],
                                                             batch_size=batch_size,
                                                             capacity=capacity,
                                                             min_after_dequeue=min_after_dequeue,
                                                             num_threads=num_threads)
        else:
            images_batch,labels_batch=tf.train.batch([images,labels],
                                                     batch_size=batch_size,
                                                     num_threads=num_threads,
                                                     capacity=capacity)
        if one_hot:
            labels_batch=tf.one_hot(labels_batch,labels_num,1,0)
        return images_batch,labels_batch
    
    
    def create_tf_records(image_base_dir,image_txt_dir,tfrecords_dir,resize_height,resize_width,shuffle,log=5):
        # write one Example per image: the raw resized pixels plus the integer label
        images_list,labels_list=load_file(image_txt_dir,shuffle)
        writer=tf.io.TFRecordWriter(tfrecords_dir)
        for i,[image_name,single_label_list] in enumerate(zip(images_list,labels_list)):
            cur_image_dir=image_base_dir+'/'+image_name
            if not os.path.exists(cur_image_dir):
                print('the image path does not exist:',cur_image_dir)
                continue
            image=Image.open(cur_image_dir)
            # note: this assumes the source images are already 3-channel RGB
            image=image.resize((resize_width,resize_height))  # PIL resize expects (width, height)
            image_raw=image.tobytes()
            single_label=single_label_list[0]
            if i % log == 0 or i == len(images_list) - 1:
                print('------------processing:%d-th------------' % (i))
            example=tf.train.Example(features=tf.train.Features(feature={
                'image_raw':_bytes_feature(image_raw),
                'label':_int64_feature(single_label)
            }))
            writer.write(example.SerializeToString())
        writer.close()
    
    
    
    
    if __name__=='__main__':
        resize_height=224
        resize_width=224
        shuffle=True
        log=5
    
        train_image_dir = 'D:/软件/pycharmProject/wenyuPy/Dataset/VGG16/train'
        train_txt_dir = 'D:/软件/pycharmProject/wenyuPy/Dataset/VGG16/train.txt'
        train_records_dir='D:/软件/pycharmProject/wenyuPy/Dataset/VGG16/record/train.tfrecords'
        create_tf_records(train_image_dir,train_txt_dir,train_records_dir,resize_height,resize_width,shuffle,log)
        train_nums=get_example_nums(train_records_dir)
        print('the train records number is:',train_nums)
    
        validation_image_dir = 'D:/软件/pycharmProject/wenyuPy/Dataset/VGG16/validation'
        validation_txt_dir = 'D:/软件/pycharmProject/wenyuPy/Dataset/VGG16/validation.txt'
        validation_records_dir = 'D:/软件/pycharmProject/wenyuPy/Dataset/VGG16/record/validation.tfrecords'
        create_tf_records(validation_image_dir,validation_txt_dir,validation_records_dir,resize_height, resize_width, shuffle, log)
        validation_nums = get_example_nums(validation_records_dir)
        print('the validation records number is:', validation_nums)
    
        test_image_dir = 'D:/软件/pycharmProject/wenyuPy/Dataset/VGG16/test'
        test_txt_dir = 'D:/软件/pycharmProject/wenyuPy/Dataset/VGG16/test.txt'
        test_records_dir = 'D:/软件/pycharmProject/wenyuPy/Dataset/VGG16/record/test.tfrecords'
        create_tf_records(test_image_dir, test_txt_dir, test_records_dir, resize_height, resize_width, shuffle, log)
        test_nums = get_example_nums(test_records_dir)
        print('the test records number is:', test_nums)
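
    The list files that load_file parses are plain text, one image per line: a path relative to that split's base directory, a space, and an integer class label. As a hypothetical illustration (the folder and file names below are made up, not from my dataset), train.txt could look like this for a five-class set:

    cat/0001.jpg 0
    dog/0001.jpg 1
    bird/0001.jpg 2
    car/0001.jpg 3
    plane/0001.jpg 4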

    These paths are from my own machine, and there are only five image classes; if you want to handle more classes it works exactly the same way, just change the paths. I didn't bother writing many comments because the code is fairly simple, haha. Feel free to repost, but if you really want to learn this you should type it out yourself. My blog is pretty average anyway, so I doubt many people will read it.
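
    A rough sketch of what the reading side could look like, matching the writer above (it assumes the TF 1.x queue API and 3-channel RGB images; the helper name read_records is just a placeholder):

    def read_records(tfrecords_dir, resize_height, resize_width):
        # queue the record file and read serialized examples one at a time
        filename_queue = tf.train.string_input_producer([tfrecords_dir])
        reader = tf.TFRecordReader()
        _, serialized_example = reader.read(filename_queue)
        features = tf.parse_single_example(
            serialized_example,
            features={
                'image_raw': tf.FixedLenFeature([], tf.string),
                'label': tf.FixedLenFeature([], tf.int64)
            })
        # decode the raw bytes written by image.tobytes() and restore the shape
        image = tf.decode_raw(features['image_raw'], tf.uint8)
        image = tf.reshape(image, [resize_height, resize_width, 3])
        image = tf.cast(image, tf.float32) * (1.0 / 255)  # optional scaling to [0, 1]
        label = tf.cast(features['label'], tf.int32)
        return image, label

    The image and label tensors it returns can be fed straight into get_batch_images, and the batches evaluated in a session after starting the queue runners with tf.train.start_queue_runners.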

  • Original post: https://www.cnblogs.com/daremosiranaihana/p/11429560.html