zoukankan      html  css  js  c++  java
  • tensorflow tfrecoder read write

      1 #  write in tfrecord
      2 import tensorflow as tf
      3 import os
      4 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
      5 
      6 
      7 FLAGS = tf.app.flags.FLAGS
      8 tf.app.flags.DEFINE_string("tfrecords_dir", "./tfrecords/captcha.tfrecords", "验证码tfrecords文件")
      9 tf.app.flags.DEFINE_string("captcha_dir", "../data/Genpics/", "验证码图片路径")
     10 tf.app.flags.DEFINE_string("letter", "ABCDEFGHIJKLMNOPQRSTUVWXYZ", "验证码字符的种类")
     11 
     12 
     13 def dealwithlabel(label_str):
     14 
     15     # 构建字符索引 {0:'A', 1:'B'......}
     16     num_letter = dict(enumerate(list(FLAGS.letter)))
     17 
     18     # 键值对反转 {'A':0, 'B':1......}
     19     letter_num = dict(zip(num_letter.values(), num_letter.keys()))
     20 
     21     print(letter_num)
     22 
     23     # 构建标签的列表
     24     array = []
     25 
     26     # 给标签数据进行处理[[b"NZPP"]......]
     27     for string in label_str:
     28 
     29         letter_list = []# [1,2,3,4]
     30 
     31         # 修改编码,bytes --> string
     32         for letter in string.decode('utf-8'):
     33             letter_list.append(letter_num[letter])
     34 
     35         array.append(letter_list)
     36 
     37     # [[13, 25, 15, 15], [22, 10, 7, 10], [22, 15, 18, 9], [16, 6, 13, 10], [1, 0, 8, 17], [0, 9, 24, 14].....]
     38     print(array)
     39 
     40     # 将array转换成tensor类型
     41     label = tf.constant(array)
     42 
     43     return label
     44 
     45 
     46 def get_captcha_image():
     47     """
     48     获取验证码图片数据
     49     :param file_list: 路径+文件名列表
     50     :return: image
     51     """
     52     # 构造文件名
     53     filename = []
     54 
     55     for i in range(6000):
     56         string = str(i) + ".jpg"
     57         filename.append(string)
     58 
     59     # 构造路径+文件
     60     file_list = [os.path.join(FLAGS.captcha_dir, file) for file in filename]
     61 
     62     # 构造文件队列
     63     file_queue = tf.train.string_input_producer(file_list, shuffle=False)
     64 
     65     # 构造阅读器
     66     reader = tf.WholeFileReader()
     67 
     68     # 读取图片数据内容
     69     key, value = reader.read(file_queue)
     70 
     71     # 解码图片数据
     72     image = tf.image.decode_jpeg(value)
     73 
     74     image.set_shape([20, 80, 3])
     75 
     76     # 批处理数据 [6000, 20, 80, 3]
     77     image_batch = tf.train.batch([image], batch_size=6000, num_threads=1, capacity=6000)
     78 
     79     return image_batch
     80 
     81 
     82 def get_captcha_label():
     83     """
     84     读取验证码图片标签数据
     85     :return: label
     86     """
     87     file_queue = tf.train.string_input_producer(["../data/Genpics/labels.csv"], shuffle=False)
     88 
     89     reader = tf.TextLineReader()
     90 
     91     key, value = reader.read(file_queue)
     92 
     93     records = [[1], ["None"]]
     94 
     95     number, label = tf.decode_csv(value, record_defaults=records)
     96 
     97     # [["NZPP"], ["WKHK"], ["ASDY"]]
     98     label_batch = tf.train.batch([label], batch_size=6000, num_threads=1, capacity=6000)
     99 
    100     return label_batch
    101 
    102 
    103 def write_to_tfrecords(image_batch, label_batch):
    104     """
    105     将图片内容和标签写入到tfrecords文件当中
    106     :param image_batch: 特征值
    107     :param label_batch: 标签纸
    108     :return: None
    109     """
    110     # 转换类型
    111     label_batch = tf.cast(label_batch, tf.uint8)
    112 
    113     print(label_batch)
    114 
    115     # 建立TFRecords 存储器
    116     writer = tf.python_io.TFRecordWriter(FLAGS.tfrecords_dir)
    117 
    118     # 循环将每一个图片上的数据构造example协议块,序列化后写入
    119     for i in range(6000):
    120         # 取出第i个图片数据,转换相应类型,图片的特征值要转换成字符串形式
    121         image_string = image_batch[i].eval().tostring()
    122 
    123         # 标签值,转换成整型
    124         label_string = label_batch[i].eval().tostring()
    125 
    126         # 构造协议块
    127         example = tf.train.Example(features=tf.train.Features(feature={
    128             "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_string])),
    129             "label": tf.train.Feature(bytes_list=tf.train.BytesList(value=[label_string]))
    130         }))
    131 
    132         writer.write(example.SerializeToString())
    133 
    134     # 关闭文件
    135     writer.close()
    136 
    137     return None
    138 
    139 
    140 if __name__ == "__main__":
    141 
    142     # 获取验证码文件当中的图片
    143     image_batch = get_captcha_image()
    144 
    145     # 获取验证码文件当中的标签数据
    146     label = get_captcha_label()
    147 
    148     print(image_batch, label)
    149 
    150     with tf.Session() as sess:
    151 
    152         coord = tf.train.Coordinator()
    153 
    154         threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    155 
    156         # 获取tensor里面的值
    157         label_str = sess.run(label)
    158 
    159         print(label_str)
    160 
    161         # 处理字符串标签到数字张量
    162         label_batch = dealwithlabel(label_str)
    163 
    164         print(label_batch)
    165 
    166         # 将图片数据和内容写入到tfrecords文件当中
    167         write_to_tfrecords(image_batch, label_batch)
    168 
    169         coord.request_stop()
    170 
    171         coord.join(threads)
     1 # read tfrecords
     2 def read_and_decode():
     3     """
     4     读取验证码数据API
     5     :return: image_batch, label_batch
     6     """
     7     # 1、构建文件队列
     8     file_queue = tf.train.string_input_producer([FLAGS.captcha_dir])
     9 
    10     # 2、构建阅读器,读取文件内容,默认一个样本
    11     reader = tf.TFRecordReader()
    12 
    13     # 读取内容
    14     key, value = reader.read(file_queue)
    15 
    16     # tfrecords格式example,需要解析
    17     features = tf.parse_single_example(value, features={
    18         "image": tf.FixedLenFeature([], tf.string),
    19         "label": tf.FixedLenFeature([], tf.string),
    20     })
    21 
    22     # 解码内容,字符串内容
    23     # 1、先解析图片的特征值
    24     image = tf.decode_raw(features["image"], tf.uint8)
    25     # 1、先解析图片的目标值
    26     label = tf.decode_raw(features["label"], tf.uint8)
    27 
    28     # print(image, label)
    29 
    30     # 改变形状
    31     image_reshape = tf.reshape(image, [20, 80, 3])
    32 
    33     label_reshape = tf.reshape(label, [4])
    34 
    35     print(image_reshape, label_reshape)
    36 
    37     # 进行批处理,每批次读取的样本数 100, 也就是每次训练时候的样本
    38     image_batch, label_btach = tf.train.batch([image_reshape, label_reshape], batch_size=FLAGS.batch_size, num_threads=1, capacity=FLAGS.batch_size)
    39 
    40     print(image_batch, label_btach)
    41     return image_batch, label_btach

    # write flags
    FLAGS = tf.app.flags.FLAGS
    tf.app.flags.DEFINE_string("tfrecords_dir", "./tfrecords/captcha.tfrecords", "验证码tfrecords文件")
    tf.app.flags.DEFINE_string("captcha_dir", "../data/Genpics/", "验证码图片路径")
    tf.app.flags.DEFINE_string("letter", "ABCDEFGHIJKLMNOPQRSTUVWXYZ", "验证码字符的种类")
    # read flags
    tf.app.flags.DEFINE_string("captcha_dir", "./tfrecords/captcha.tfrecords", "验证码数据的路径")
    tf.app.flags.DEFINE_integer("batch_size", 100, "每批次训练的样本数")
    tf.app.flags.DEFINE_integer("label_num", 4, "每个样本的目标值数量")
    tf.app.flags.DEFINE_integer("letter_num", 26, "每个目标值取的字母的可能心个数")
  • 相关阅读:
    KMP算法
    cocos2d-x jsbinding 在线更新策略设计
    AS3动画效果公式,常用处理公式代码,基本运动公式,三角公式
    理解引导行为:路径跟踪
    适用于任何语言的编程策略
    Using中return对象
    js计算两个时间相差天数
    fastReport 绑定DataBand数据源后还是打印出一条数据
    无法处理文件 MainForm.resx,因为它位于 Internet 或受限区域中,或者文件上具有 Web 标记。要想处理这些文件,请删除 Web 标记
    附加数据库后登陆报错
  • 原文地址:https://www.cnblogs.com/jiujue/p/11453779.html
Copyright © 2011-2022 走看看