zoukankan      html  css  js  c++  java
  • tf.train.batch的偶尔乱序问题

    tf.train.batch的偶尔乱序问题

    我的微博我的github我的B站

    tf.train.batch的偶尔乱序问题

    • 我们在通过tf.Reader读取文件后,都需要用batch函数将读取的数据根据预先设定的batch_size打包为一个个独立的batch方便我们进行学习。
    • 常用的batch函数有tf.train.batch和tf.train.shuffle_batch函数。前者是将数据从前往后读取并顺序打包,后者则要进行乱序处理————即将读取的数据进行乱序后在组成批次。
    • 训练时我往往都是使用shuffle_batch函数,但是这次我在验证集上预调好模型并freeze模型后我需要在测试集上进行测试。此时我需要将数据的标签和inference后的结果进行一一对应。 此时数据出现的顺序是十分重要的,这保证我们的产品在上线前的测试集中能准确get到每个数据和inference后结果的差距 而在验证集中我们不太关心数据原有的标签和inference后的真实值,我们往往只是需要让这两个数据一一对应,关于数据出现的顺序我们并不关心。
    • 此时我们一般使用tf.train.batch函数将tf.Reader读取的值进行顺序打包即可。

    然而tf.train.batch函数往往会有偶尔乱序的情况

    • 我们将csv文件中每个数据样本从上往下依次进行标号,我们在使用tf.trian.batch函数依次进行读取,如果我们读取的数据编号乱序了,则表明tf.train.batch函数有偶尔乱序的状况。

    源程序文件下载
    test_tf_train_batch.csv

    import tensorflow as tf
    
    BATCH_SIZE = 400
    NUM_THREADS = 2
    MAX_NUM = 500
    
    
    def read_data(file_queue):
        reader = tf.TextLineReader(skip_header_lines=1)
        key, value = reader.read(file_queue)
        defaults = [[0], [0.], [0.]]
        NUM, C, Tensile = tf.decode_csv(value, defaults)
        vertor_example = tf.stack([C])
        vertor_label = tf.stack([Tensile])
        vertor_num = tf.stack([NUM])
    
        return vertor_example, vertor_label, vertor_num
    
    
    def create_pipeline(filename, batch_size, num_threads):
        file_queue = tf.train.string_input_producer([filename])  # 设置文件名队列
        example, label, no = read_data(file_queue)  # 读取数据和标签
    
        example_batch, label_batch, no_batch = tf.train.batch(
            [example, label, no], batch_size=batch_size, num_threads=num_threads, capacity=MAX_NUM)
    
        return example_batch, label_batch, no_batch
    
    
    x_train_batch, y_train_batch, no_train_batch = create_pipeline('test_tf_train_batch.csv', batch_size=BATCH_SIZE,
                                                                   num_threads=NUM_THREADS)
    
    init_op = tf.global_variables_initializer()
    local_init_op = tf.local_variables_initializer()
    with tf.Session() as sess:
        sess.run(local_init_op)
        sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        example, label, num = sess.run([x_train_batch, y_train_batch, no_train_batch])
        print(example)
        print(label)
        print(num)
        coord.request_stop()
        coord.join(threads)
    
    

    实验结果

    我们将csv文件中的真实Tensile值放在第一列,将使用tf.train.batch函数得到的Tensile和no分别放在第二列和第三列

    TureTensile FalseTensile NO
    0.830357143 [ 0.52678573] [ 66]
    0.526785714 [ 0.83035713] [ 65]
    0.553571429 [ 0.4375 ] [ 68]
    0.4375 [ 0.5535714 ] [ 67]
    0.517857143 [ 0.33035713] [ 70]
    0.330357143 [ 0.51785713] [ 69]
    0.482142857 [ 0.6785714 ] [ 72]
    0.678571429 [ 0.48214287] [ 71]
    0.419642857 [ 0.02678571] [ 74]
    0.026785714 [ 0.41964287] [ 73]
    0.401785714 [ 0.4017857 ] [ 75]

    解决方案

    • 将测试集中所有样本数据加NO顺序标签列
  • 相关阅读:
    推荐20个开源项目托管网站
    python 网络编程(网络基础之网络协议篇)
    python 异常处理
    python 内置函数的补充 isinstance,issubclass, hasattr ,getattr, setattr, delattr,str,del 用法,以及元类
    python3 封装之property 多态 绑定方法classmethod 与 非绑定方法 staticmethod
    python3 类 组合
    PYTHON3中 类的继承
    面向对象 与类
    包 与常用模块
    json 与pickle模块(序列化与反序列化))
  • 原文地址:https://www.cnblogs.com/cloud-ken/p/9092010.html
Copyright © 2011-2022 走看看