zoukankan      html  css  js  c++  java
  • TFRecord 使用

    TFRecord 生成

    import os
    import xmltodict
    import tensorflow as tf
    import numpy as np
    
    # Raw strings keep the Windows backslashes literal: in a normal literal the
    # '\v' of '\voc2012.tfrecord' is a vertical-tab escape and corrupts the path.
    # Layout follows the standard VOCdevkit/VOC2012/{Annotations,JPEGImages} tree.
    dir_path = r'F:\数据存储\VOCdevkit\VOC2012\Annotations'
    # Sorted so the record order is reproducible; only XML annotations are parsed.
    dirs = sorted(f for f in os.listdir(dir_path) if f.lower().endswith('.xml'))
    imgs_dir = r'F:\数据存储\VOCdevkit\VOC2012\JPEGImages'
    out_path = r'F:\数据存储\VOCdevkit\voc2012.tfrecord'

    # Index 0 is reserved for "background", so stored object class ids are 1..20.
    classes = [
        "background", "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat",
        "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person",
        "pottedplant", "sheep", "sofa", "train", "tvmonitor"
    ]
    # Shared session used by get_and_resize_img to evaluate the image ops.
    sess = tf.Session()
    
    
    # Decode/resize ops built once per target size, keyed by (height, width).
    _resize_graph_cache = {}


    def get_and_resize_img(img_file, size=(224, 224)):
        '''
        Read a JPEG from ``imgs_dir`` and resize it (224x224 by default).

        Args:
            img_file: image file name relative to ``imgs_dir``.
            size: target (height, width); defaults to the original 224x224.

        Returns:
            (resized_img_str, w_scale, h_scale, shape_new):
            resized_img_str is the raw uint8 bytes of the resized image,
            w_scale = new_width / old_width and h_scale = new_height / old_height
            (used to rescale bounding boxes), and shape_new is the
            (height, width, channels) shape of the resized array.
        '''
        size = tuple(size)
        ops = _resize_graph_cache.get(size)
        if ops is None:
            # Build the ops once and feed the path through a placeholder; creating
            # new ops on every call keeps growing the default graph for every image.
            path_ph = tf.placeholder(tf.string, shape=[])
            decoded = tf.image.decode_jpeg(tf.read_file(path_ph))
            resized = tf.image.resize_images(decoded, list(size), method=0)
            ops = (path_ph, tf.shape(decoded), resized)
            _resize_graph_cache[size] = ops
        path_ph, old_shape_op, resized_op = ops

        # A single run fetches the original shape and the resized pixels, so the
        # file is read and decoded once instead of twice.
        shape_old, resized_img = sess.run(
            [old_shape_op, resized_op],
            feed_dict={path_ph: imgs_dir + '/' + img_file})
        resized_img = np.asarray(resized_img, dtype='uint8')
        shape_new = resized_img.shape
        # Shapes are (height, width, channels): width is axis 1, height is axis 0.
        w_scale = shape_new[1] / int(shape_old[1])
        h_scale = shape_new[0] / int(shape_old[0])

        return resized_img.tobytes(), w_scale, h_scale, shape_new
    
    
    def _annotation_objects(doc):
        '''
        Return the <object> entries of a parsed VOC annotation as a list.

        xmltodict yields a single mapping when the XML has exactly one <object>,
        a list when it has several, and no key at all when it has none. Depending
        on the xmltodict version the mapping is an OrderedDict or a plain dict,
        so isinstance(..., dict) is used to cover both.
        '''
        objects = doc['annotation'].get('object', [])
        if isinstance(objects, dict):
            objects = [objects]
        return objects


    writer = tf.python_io.TFRecordWriter(out_path)
    try:
        for file_name in dirs:
            img_file_name = os.path.splitext(file_name)[0]
            # Binary mode lets the XML declaration decide the text encoding.
            with open(dir_path + '/' + file_name, 'rb') as xml_txt:
                doc = xmltodict.parse(xml_txt.read())
            resized_img_str, w_scale, h_scale, shape = get_and_resize_img(img_file_name + '.jpg')

            img_obtain_classes = []
            y_mins = []
            x_mins = []
            y_maxes = []
            x_maxes = []
            for one_object in _annotation_objects(doc):
                if one_object['name'] not in classes:
                    continue
                bndbox = one_object['bndbox']
                img_obtain_classes.append(classes.index(one_object['name']))
                # Rescale box corners into the resized image's coordinates;
                # float() also tolerates non-integer coordinate strings.
                y_mins.append(float(h_scale * float(bndbox['ymin'])))
                x_mins.append(float(w_scale * float(bndbox['xmin'])))
                y_maxes.append(float(h_scale * float(bndbox['ymax'])))
                x_maxes.append(float(w_scale * float(bndbox['xmax'])))

            example = tf.train.Example(features=tf.train.Features(feature={
                'filename': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_file_name.encode('utf8')])),
                'shape': tf.train.Feature(int64_list=tf.train.Int64List(value=[shape[0], shape[1], shape[2]])),
                'classes': tf.train.Feature(int64_list=tf.train.Int64List(value=img_obtain_classes)),
                'y_mins': tf.train.Feature(float_list=tf.train.FloatList(value=y_mins)),  # per-object ymin
                'x_mins': tf.train.Feature(float_list=tf.train.FloatList(value=x_mins)),
                'y_maxes': tf.train.Feature(float_list=tf.train.FloatList(value=y_maxes)),
                'x_maxes': tf.train.Feature(float_list=tf.train.FloatList(value=x_maxes)),
                'encoded': tf.train.Feature(bytes_list=tf.train.BytesList(value=[resized_img_str]))
            }))
            writer.write(example.SerializeToString())
    finally:
        # Close even if an annotation or image fails, so written records are flushed.
        writer.close()
    sess.close()
    print('ok')
    
    

    TFRecord 读取

    import tensorflow as tf
    import numpy as np
    from matplotlib import pyplot as plt
    # import sys
    #
    # sys.path.append("..")
    
    # Must match the list used when writing the TFRecord: the writer stores
    # classes.index(name) with "background" at index 0, so stored ids are 1..20
    # and only this layout maps them back to the right class names.
    classes = [
        "background", "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat",
        "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person",
        "pottedplant", "sheep", "sofa", "train", "tvmonitor"
    ]
    
    
    
    # 'filename': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_file_name])),
    # 'shape': tf.train.Feature(int64_list=tf.train.Int64List(value=[shape[0], shape[1], shape[2]])),
    # 'classes': tf.train.Feature(int64_list=tf.train.Int64List(value=np.array(img_obtain_classes))),
    # 'y_mins': tf.train.Feature(float_list=tf.train.FloatList(value=y_mins)),  # 各个 object 的  ymin
    # 'x_mins': tf.train.Feature(float_list=tf.train.FloatList(value=x_mins)),
    # 'y_maxes': tf.train.Feature(float_list=tf.train.FloatList(value=y_maxes)),
    # 'x_maxes': tf.train.Feature(float_list=tf.train.FloatList(value=x_maxes)),
    # 'encoded': tf.train.Feature(bytes_list=tf.train.BytesList(value=[resized_img_str]))
    
    def _parse_record(example_proto):
        '''
        Parse one serialized tf.train.Example produced by the VOC writer script.

        Fixed-length fields: 'filename' (scalar string), 'shape' (3 int64 values)
        and 'encoded' (scalar string holding the raw resized image bytes).
        Per-object fields ('classes' and the four box coordinates) are declared
        VarLenFeature, so they are returned as sparse tensors.
        '''
        box_keys = ('y_mins', 'x_mins', 'y_maxes', 'x_maxes')
        spec = {
            'filename': tf.FixedLenFeature([], tf.string),
            'shape': tf.FixedLenFeature([3], tf.int64),
            'classes': tf.VarLenFeature(tf.int64),
        }
        spec.update({key: tf.VarLenFeature(tf.float32) for key in box_keys})
        spec['encoded'] = tf.FixedLenFeature((), tf.string)
        return tf.parse_single_example(example_proto, features=spec)
    
    
    def read_test(input_file, num_records=2):
        '''
        Print and display the first records of the VOC TFRecord file.

        Args:
            input_file: path of the TFRecord written by the generation script.
            num_records: how many records to show (default 2, as before). Stops
                early, without raising, if the file holds fewer records.
        '''
        # Read the TFRecord file with tf.data.
        dataset = tf.data.TFRecordDataset(input_file)
        dataset = dataset.map(_parse_record)
        iterator = dataset.make_initializable_iterator()
        # Create the get_next op once; calling get_next() inside the loop would
        # add a new op to the graph on every iteration.
        next_element = iterator.get_next()
        with tf.Session() as sess:
            sess.run(iterator.initializer)
            for _ in range(num_records):
                try:
                    features = sess.run(next_element)
                except tf.errors.OutOfRangeError:
                    break
                name = features['filename'].decode()
                shape = features['shape']
                # VarLenFeature fields come back sparse; .values holds the entries.
                obj_classes = features['classes']
                y_mins = features['y_mins']
                x_mins = features['x_mins']
                y_maxes = features['y_maxes']
                x_maxes = features['x_maxes']
                img_data = features['encoded']

                print(len(img_data))
                print('=======')
                print("shape", shape)
                print("name", name)
                print("classes", obj_classes.values)
                print("y_mins", y_mins.values)
                print("x_mins", x_mins.values)
                print("y_maxes", y_maxes.values)
                print("x_maxes", x_maxes.values)
                # Rebuild the image array from the raw uint8 bytes and the stored shape.
                image_data = np.reshape(np.frombuffer(img_data, dtype=np.uint8), shape)
                print("img_data", image_data)
                plt.imshow(image_data)
                plt.show()
    
    
    # Raw string: in a normal literal '\v' of '\voc2012' is a vertical-tab escape.
    read_test(r'F:\数据存储\VOCdevkit\voc2012.tfrecord')
    
    
    

    尺寸不固定矩阵的存储和读取

    import json
    import jieba
    import tensorflow as tf
    
    # Vocabulary: JSON object whose "all_words_word2id" entry maps word -> id.
    with open('../data_save/words_info.txt', 'r', encoding='utf-8') as words_file:
        dic = json.load(words_file)
        all_words_word2id = dic["all_words_word2id"]

    # One stop word per line. rstrip instead of line[:-1] so the last word is not
    # truncated when the file does not end with a newline.
    with open('./stop_words.txt', encoding='utf-8') as f:
        stop_words = {line.rstrip('\r\n') for line in f}
    print('停用词读取完毕,共{n}个单词'.format(n=len(stop_words)))
    
    # Raw strings: in a normal literal the '\n' of '\news2016zh...' is a newline
    # escape, which silently corrupts these Windows paths.
    dir_path = r'F:\数据存储\新闻语料\news2016zh_train.json'
    dir_path_test = r'F:\数据存储\新闻语料\news2016zh_valid.json'
    out_path = r'F:\数据存储\新闻语料\news2016zh_train_new.tfrecord'
    
    
    def getCutSequnce(line):
        '''
        Segment a Chinese string with jieba (accurate mode, cut_all=False).

        Drops tokens found in the module-level ``stop_words`` collection as well
        as the URL fragments 'www', 'com' and 'http', and returns the remaining
        tokens as a list in their original order.
        '''
        url_noise = ('www', 'com', 'http')
        return [
            token
            for token in jieba.cut(line, cut_all=False)
            if token not in stop_words and token not in url_noise
        ]
    
    
    # Limits kept from the original script: read at most MAX_LINES input lines and
    # skip articles whose content is longer than MAX_CONTENT_LEN characters.
    MAX_LINES = 10000
    MAX_CONTENT_LEN = 3000


    def _to_ids(words):
        '''Map words to ids via all_words_word2id, dropping out-of-vocabulary words.'''
        return [all_words_word2id[w] for w in words if w in all_words_word2id]


    writer = tf.python_io.TFRecordWriter(out_path)
    try:
        # One JSON object per line; i counts every line read, including skipped ones.
        with open(dir_path, encoding='utf-8') as txt:
            for i, one_dic in enumerate(txt, start=1):
                if i > MAX_LINES:
                    break
                if i % 1000 == 0:
                    print(i)
                one_dic_json = json.loads(one_dic)

                title = one_dic_json['title']
                content = one_dic_json['content']
                if len(content) > MAX_CONTENT_LEN or not title or not content:
                    continue

                # Variable-length id sequences are stored as plain int64 lists.
                example = tf.train.Example(features=tf.train.Features(feature={
                    'title': tf.train.Feature(int64_list=tf.train.Int64List(value=_to_ids(getCutSequnce(title)))),
                    'content': tf.train.Feature(int64_list=tf.train.Int64List(value=_to_ids(getCutSequnce(content))))
                }))
                writer.write(example.SerializeToString())
    finally:
        # The original never closed the writer, so buffered records might not be
        # flushed to disk; always close it, even on errors.
        writer.close()
    
    
    
    
    
    
    import tensorflow as tf
    import numpy as np
    
    def _parse_record(example_proto):
        '''
        Parse one serialized Example holding the 'title' and 'content' id lists.

        Both fields are declared VarLenFeature(int64) because their lengths differ
        per article, so they are returned as sparse tensors.
        '''
        spec = {key: tf.VarLenFeature(dtype=tf.int64) for key in ('title', 'content')}
        return tf.parse_single_example(example_proto, features=spec)
    
    def read_test(input_file, num_records=5):
        '''
        Print the first content id sequences of the news TFRecord file.

        Args:
            input_file: path of the TFRecord written by the news writer script.
            num_records: how many records to print (default 5, as before). Stops
                early, without raising, if the file holds fewer records.
        '''
        # Read the TFRecord file with tf.data.
        dataset = tf.data.TFRecordDataset(input_file)
        dataset = dataset.map(_parse_record)
        iterator = dataset.make_initializable_iterator()
        # Create the get_next op once instead of adding a new op every iteration.
        next_element = iterator.get_next()
        with tf.Session() as sess:
            sess.run(iterator.initializer)
            for _ in range(num_records):
                try:
                    features = sess.run(next_element)
                except tf.errors.OutOfRangeError:
                    break
                content = features['content']

                print("xx", content)
                # content is sparse; its variable-length id sequence lives in
                # .values (np.array() of the whole sparse value does not give it).
                print("xx", content.values.shape)
               
    # Raw string: in a normal literal '\n' of '\news2016zh...' is a newline escape.
    read_test(r'F:\数据存储\新闻语料\news2016zh_train_new.tfrecord')
    
    
    

    统计数据条数

    import tensorflow as tf
    
    
    def total_sample(file_name):
        '''Return how many records the TFRecord file ``file_name`` contains (0 if empty).'''
        return sum(1 for _ in tf.python_io.tf_record_iterator(file_name))
    
    
    # Raw string: in a normal literal '\n' of '\news2016zh...' is a newline escape.
    result = total_sample(r'F:\数据存储\新闻语料\news2016zh_train_new.tfrecord')
    print(result)
    
    
  • 相关阅读:
    谷歌浏览器中安装JsonView扩展程序
    谷歌浏览器中安装Axure扩展程序
    PreferencesUtils【SharedPreferences操作工具类】
    Eclipse打包出错——提示GC overhead limit exceeded
    IntentActionUtil【Intent的常见作用的工具类】
    DeviceUuidFactory【获取设备唯一标识码的UUID(加密)】【需要运行时权限的处理的配合】
    AndroidStudio意外崩溃,电脑重启,导致重启打开Androidstudio后所有的import都出错
    DateTimeHelper【日期类型与字符串互转以及日期对比相关操作】
    ACache【轻量级的开源缓存框架】
    WebUtils【MD5加密(基于MessageDigest)】
  • 原文地址:https://www.cnblogs.com/panfengde/p/11302960.html
Copyright © 2011-2022 走看看