zoukankan      html  css  js  c++  java
  • TFRecord 存入图像和标签

    #-*- coding:utf-8 -*-
    import os
    import tensorflow as tf
    import cv2
    
    '''
    文件目录为
    chiwawa/
         xx.jpg
         xx.jpg
         .....
    japandog/
         xx.jpg
         xx.jpg
         .....
    '''
    cwd = 'f:/py/tfrecord/'
    classes={'chiwawa','japandog'} # 需要存入的标签,尽量与文件名一致,方便操作
    
    sess = tf.Session()
    writer = tf.python_io.TFRecordWriter("f:/py/tfrecord/train.tfrecords") # 建立一个writer
    for index, name in enumerate(classes):
        class_path = cwd + name + "/"           # 构建文件路径
        for img_name in os.listdir(class_path): # 遍历目录下的文件
            img_path = class_path + img_name     # 构建具体每一张图片的路径
            image = cv2.imread(img_path)        # 读取图片
    
            # 获取图片的宽,高和通道数
            img_w = image.shape[0]
            img_h = image.shape[1]
            img_c = image.shape[2]
    
            # tf读取图片
            img = tf.read_file(img_path)
            img = tf.image.decode_jpeg(img)
    
            # img = tf.image.resize_images(img,(224, 224)) 改变大小
            img_raw = sess.run(tf.cast(img,tf.uint8)).tostring()              #将图片转化为原生bytes
            
    
            label = name.encode('utf-8')  #将标签转化为bytes
            '''
            以下是Example类的常用固定格式,但要注意第一个features有s,对应的是tf.train.Features
            tf.train.Features里的feature是没有s的,bytes_list对应的是tf.train.BytesList,
            int64_list对应的是tf.train.Int64List,输入的value的格式也要一致,可输入的格式有int,float,bytes
            label和img_raw的格式是bytes,宽、高、通道数的格式是int
            '''
            example = tf.train.Example(features=tf.train.Features(feature={
                "label": tf.train.Feature(bytes_list=tf.train.BytesList(value=[label])),
                'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
                'img_w': tf.train.Feature(int64_list=tf.train.Int64List(value=[img_w])),
                'img_h': tf.train.Feature(int64_list=tf.train.Int64List(value=[img_h])),
                'img_c': tf.train.Feature(int64_list=tf.train.Int64List(value=[img_c]))
            }))
            writer.write(example.SerializeToString())  #序列化为字符串
            
    writer.close() 
  • 相关阅读:
    script标签加载顺序(defer & async)
    nginx反向代理vue访问时浏览器加载失败,出现 ERR_CONTENT_LENGTH_MISMATCH 问题
    Git每次进入都需要输入用户名和密码的问题解决
    update select
    sql --- where concat
    GO -- 正则表达式
    浏览器中回车(Enter)和刷新的区别是什么?[转载]
    转: Linux --- Supervisor的作用与配置
    Golang 使用Map构建Set类型的实现方法
    linux -- 查看应用启动时间
  • 原文地址:https://www.cnblogs.com/xinghun85/p/9119872.html
Copyright © 2011-2022 走看看