zoukankan      html  css  js  c++  java
  • TensorFlowIO操作(三)------图像操作

    图像操作

    图像基本概念

    在图像数字化表示当中,分为黑白和彩色两种。在数字化表示图片的时候,有三个因素。分别是图片的长、图片的宽、图片的颜色通道数。那么黑白图片的颜色通道数为1,它只需要一个数字就可以表示一个像素位;而彩色照片就不一样了,它有三个颜色通道,分别为RGB,通过三个数字表示一个像素位。TensorFlow支持JPG、PNG图像格式,RGB、RGBA颜色空间。图像用与图像尺寸相同(heightwidthchnanel)张量表示。图像所有像素存在磁盘文件,需要被加载到内存。

    图像大小压缩

    大尺寸图像输入占用大量系统内存。训练CNN需要大量时间,加载大文件增加更多训练时间,也难存放多数系统GPU显存。大尺寸图像大量无关本征属性信息,影响模型泛化能力。最好在预处理阶段完成图像操作,缩小、裁剪、缩放、灰度调整等。图像加载后,翻转、扭曲,使输入网络训练信息多样化,缓解过拟合。Python图像处理框架PIL、OpenCV。TensorFlow提供部分图像处理方法。

    • tf.image.resize_images 压缩图片导致定大小

    图像数据读取实例

    同样图像加载与二进制文件相同。图像需要解码。输入生成器(tf.train.string_input_producer)找到所需文件,加载到队列。tf.WholeFileReader 加载完整图像文件到内存,WholeFileReader.read 读取图像,tf.image.decode_jpeg 解码JPEG格式图像。图像是三阶张量。RGB值是一阶张量。加载图像格 式为[batch_size,image_height,image_width,channels]。批数据图像过大过多,占用内存过高,系统会停止响应。直接加载TFRecord文件,可以节省训练时间。支持写入多个样本。

    读取图片数据到Tensor

    管道读端多文件内容处理

    但是会发现read只返回一个图片的值。所以我们在之前处理文件的整个流程中,后面的内容队列的出队列需要用特定函数去获取。

    • tf.train.batch 读取指定大小(个数)的张量
    • tf.train.shuffle_batch 乱序读取指定大小(个数)的张量
    def readpic_decode(file_list):
        """
        批量读取图片并转换成张量格式
        :param file_list: 文件名目录列表
        :return: None
        """
    
        # 构造文件队列
        file_queue = tf.train.string_input_producer(file_list)
    
        # 图片阅读器和读取数据
        reader = tf.WholeFileReader()
        key,value = reader.read(file_queue)
    
        # 解码成张量形式
    
        image_first = tf.image.decode_jpeg(value)
    
        print(image_first)
    
        # 缩小图片到指定长宽,不用指定通道数
        image = tf.image.resize_images(image_first,[256,256])
    
        # 设置图片的静态形状
        image.set_shape([256,256,3])
    
        print(image)
    
        # 批处理图片数据,tensors是需要具体的形状大小
        image_batch = tf.train.batch([image],batch_size=100,num_threads=1,capacity=100)
    
        tf.summary.image("pic",image_batch)
    
        with tf.Session() as sess:
    
            merged = tf.summary.merge_all()
    
            filewriter = tf.summary.FileWriter("/tmp/summary/dog/",graph=sess.graph)
    
            # 线程协调器
            coord = tf.train.Coordinator()
    
            # 开启线程
            threads = tf.train.start_queue_runners(sess=sess,coord=coord)
    
            print(sess.run(image_batch))
    
            summary = sess.run(merged)
    
            filewriter.add_summary(summary)
    
            # 等待线程回收
            coord.request_stop()
            coord.join(threads)
    
    
        return None
    
    
    if __name__=="__main__":
    
        # 获取文件列表
        filename = os.listdir("./dog/")
    
        # 组合文件目录和文件名
        file_list = [os.path.join("./dog/",file) for file in filename]
    
        # 调用读取函数
        readpic_decode(file_list)

    读取TfRecords文件数据

    #CIFAR-10的数据读取以及转换成TFRecordsg格式
    
    #1、数据的读取
    
    FLAGS = tf.app.flags.FLAGS
    
    tf.app.flags.DEFINE_string("data_dir","./cifar10/cifar-10-batches-bin/","CIFAR数据目录")
    tf.app.flags.DEFINE_integer("batch_size",50000,"样本个数")
    tf.app.flags.DEFINE_string("records_file","./cifar10/cifar.tfrecords","tfrecords文件位置")
    
    class CifarRead(object):
    
        def __init__(self,filename):
            self.filelist = filename
    
            # 定义图片的长、宽、深度,标签字节,图像字节,总字节数
            self.height = 32
            self.width = 32
            self.depth = 3
            self.label_bytes = 1
            self.image_bytes = self.height*self.width*self.depth
            self.bytes = self.label_bytes + self.image_bytes
    
    
        def readcifar_decode(self):
            """
            读取数据,进行转换
            :return: 批处理的图片和标签
            """
    
            # 1、构造文件队列
            file_queue = tf.train.string_input_producer(self.filelist)
    
            # 2、构造读取器,读取内容
            reader = tf.FixedLengthRecordReader(self.bytes)
    
            key,value = reader.read(file_queue)
    
            # 3、文件内容解码
            image_label = tf.decode_raw(value,tf.uint8)
    
            # 分割标签与图像张量,转换成相应的格式
    
            label = tf.cast(tf.slice(image_label,[0],[self.label_bytes]),tf.int32)
    
            image = tf.slice(image_label,[self.label_bytes],[self.image_bytes])
    
            print(image)
    
            # 给image设置形状,防止批处理出错
            image_tensor = tf.reshape(image,[self.height,self.width,self.depth])
    
            print(image_tensor.eval())
            # depth_major = tf.reshape(image, [self.depth,self.height, self.width])
            # image_tensor = tf.transpose(depth_major, [1, 2, 0])
    
            # 4、处理流程
            image_batch,label_batch = tf.train.batch([image_tensor,label],batch_size=10,num_threads=1,capacity=10)
    
    
            return image_batch,label_batch
    
    
        def convert_to_tfrecords(self,image_batch,label_batch):
            """
            转换成TFRecords文件
            :param image_batch: 图片数据Tensor
            :param label_batch: 标签数据Tensor
            :param sess: 会话
            :return: None
            """
    
            # 创建一个TFRecord存储器
            writer = tf.python_io.TFRecordWriter(FLAGS.records_file)
    
            # 构造每个样本的Example
            for i in range(10):
                print("---------")
                image = image_batch[i]
                # 将单个图片张量转换为字符串,以可以存进二进制文件
                image_string = image.eval().tostring()
    
                # 使用eval需要注意的是,必须存在会话上下文环境
                label = int(label_batch[i].eval()[0])
    
                # 构造协议块
                example = tf.train.Example(features=tf.train.Features(feature={
                    "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_string])),
                    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
                })
                )
    
                # 写进文件
                writer.write(example.SerializeToString())
    
            writer.close()
    
            return None
    
        def read_from_tfrecords(self):
            """
            读取tfrecords
            :return: None
            """
            file_queue = tf.train.string_input_producer(["./cifar10/cifar.tfrecords"])
    
            reader = tf.TFRecordReader()
    
            key, value = reader.read(file_queue)
    
            features = tf.parse_single_example(value, features={
                "image":tf.FixedLenFeature([], tf.string),
                "label":tf.FixedLenFeature([], tf.int64),
            })
    
            image = tf.decode_raw(features["image"], tf.uint8)
    
            # 设置静态形状,可用于转换动态形状
            image.set_shape([self.image_bytes])
    
            print(image)
    
            image_tensor = tf.reshape(image,[self.height,self.width,self.depth])
    
            print(image_tensor)
    
            label = tf.cast(features["label"], tf.int32)
    
            print(label)
    
            image_batch, label_batch = tf.train.batch([image_tensor, label],batch_size=10,num_threads=1,capacity=10)
            print(image_batch)
            print(label_batch)
    
            with tf.Session() as sess:
                coord = tf.train.Coordinator()
    
                threads = tf.train.start_queue_runners(sess=sess,coord=coord)
    
                print(sess.run([image_batch, label_batch]))
    
                coord.request_stop()
                coord.join(threads)
    
            return None
    
    
    if __name__=="__main__":
        # 构造文件名字的列表
        filename = os.listdir(FLAGS.data_dir)
        file_list = [os.path.join(FLAGS.data_dir, file) for file in filename if file[-3:] == "bin"]
    
        cfar = CifarRead(file_list)
        # image_batch,label_batch = cfar.readcifar_decode()
        cfar.read_from_tfrecords()
    
        with tf.Session() as sess:
    
    
            # 构建线程协调器
            coord = tf.train.Coordinator()
    
            # 开启线程
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    
            # print(sess.run(image_batch))
    
            # 存进文件
            # cfar.convert_to_tfrecords(image_batch, label_batch)
    
    
            coord.request_stop()
            coord.join(threads)
  • 相关阅读:
    关于spring security的若干事情
    .net2005 datagridview 如何获取值改变的单元格的集合??(小弟没有为datagridview添加数据源,也就是说单元格中的数据是手工录入的)
    关于做一个通用打印类的设想,大家谈谈看法
    请教C#,两个类中的变量互访问题
    刚发现了一个问题,关于vs2005 datagridview的,我发现在设计行标头的HeaderCell.Value的时候要是设置RowTemplate.Height 的值>= 17则行标头的那个黑三角就显示出来了,要是小于17就不能显示了,想问问大家,是怎么回事?
    软件架构模式基本概念及三者区别
    以英雄联盟的方式建模,谈对依赖注入(DI)的理解以及Autofac的用法(一)
    适配器模式
    [翻译] WCF运行时架构
    关于synchronized 影响可见性的问题
  • 原文地址:https://www.cnblogs.com/fwl8888/p/9794526.html
Copyright © 2011-2022 走看看