zoukankan      html  css  js  c++  java
  • tf.train.string_input_producer()

    处理从文件中读数据

    官方说明

    简单使用

    示例中读取的是csv文件,如果要读tfrecord的文件,需要换成 tf.TFRecordReader

    import tensorflow as tf

    # Queue of input CSV file names; to read TFRecord files, swap the reader
    # below for tf.TFRecordReader.
    file_queue = tf.train.string_input_producer(["file0.csv", "file1.csv"])

    line_reader = tf.TextLineReader()
    record_key, record_value = line_reader.read(file_queue)

    # One default per column, used for empty fields; the defaults also fix the
    # decoded dtype of each column (int32 here).
    record_defaults = [[1], [1], [1], [1], [1]]
    col1, col2, col3, col4, col5 = tf.decode_csv(record_value, record_defaults=record_defaults)
    features = tf.stack([col1, col2, col3, col4])

    with tf.Session() as sess:
        # Start the queue-runner threads that populate the filename queue.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        # Pull twelve single examples, one per session run.
        for step in range(12):
            example, label = sess.run([features, col5])
            print(example, label)

        coord.request_stop()
        coord.join(threads)

    运行结果:

    结合批处理

    import tensorflow as tf
    def read_my_file_format(filename_queue):
        """Read one CSV line from *filename_queue* and split it into (features, label).

        Returns a pair: a rank-1 tensor stacking the first four columns, and the
        fifth column as the label.
        """
        line_reader = tf.TextLineReader()
        _, csv_row = line_reader.read(filename_queue)
        # Per-column defaults for empty fields; they also fix the decoded dtypes.
        defaults = [[1], [1], [1], [1], [1]]
        c1, c2, c3, c4, c5 = tf.decode_csv(csv_row, record_defaults=defaults)
        return tf.stack([c1, c2, c3, c4]), c5
    
    def input_pipeline(filenames, batch_size, num_epochs=None):
        """Build a shuffled batched input pipeline over the given CSV files.

        Returns a pair of batched tensors (example_batch, label_batch).
        """
        queue = tf.train.string_input_producer(filenames, num_epochs=num_epochs, shuffle=True)
        example, label = read_my_file_format(queue)
        # Capacity rule of thumb:
        #   min_after_dequeue + (num_threads + a small safety margin) * batch_size
        min_after_dequeue = 100
        batch = tf.train.shuffle_batch(
            [example, label],
            batch_size=batch_size,
            capacity=min_after_dequeue + 3 * batch_size,
            min_after_dequeue=min_after_dequeue)
        return batch[0], batch[1]
    
    x, y = input_pipeline(["file0.csv", "file1.csv"], 5, 4)

    sess = tf.Session()
    # num_epochs creates a *local* variable behind the scenes, so local variables
    # must be initialized too. tf.initialize_local_variables() is deprecated and
    # removed in later releases; tf.local_variables_initializer() is the
    # canonical replacement.
    sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    try:
        print("in try")
        while not coord.should_stop():
            # Run training steps or whatever
            example, label = sess.run([x, y])
            print(example, label)
            print("ssss")

    except tf.errors.OutOfRangeError:
        # Raised by the queue once the epoch limit (num_epochs=4) is exhausted.
        print('Done training -- epoch limit reached')
    finally:
        # When done, ask the threads to stop.
        coord.request_stop()

    # Wait for threads to finish.
    coord.join(threads)
    sess.close()

    运行结果:

  • 相关阅读:
    今天玩wow的时候遇到的事,汗!
    EF4.0数据DateTime2数据类型错误
    在64位windows server 2003的IIS6上运行32位的.NET程序
    CS0016: 未能写入输出文件“c:\WINDOWS\Microsoft.NET\Framework\.。。”“拒绝访问
    [转]网银支付接口编程资料汇总
    [转]Asp.Net MVC之ViewData字典与ViewModel模式
    MVC3小技巧
    MVC3+Spring.net+NHibernate+ExtJs的简单架构
    我的程序之路 asp.net 接触二年,工作一年总结 (2) lcs
    ASPXspy2.aspx lcs
  • 原文地址:https://www.cnblogs.com/helloworld0604/p/10044748.html
Copyright © 2011-2022 走看看