zoukankan      html  css  js  c++  java
  • TensorFlow读取CSV数据(批量)

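    直接上代码（基于 TensorFlow 1.x 的队列式输入管道 API，如 tf.TextLineReader、tf.train.string_input_producer；在 TensorFlow 2.x 中需改用 tf.compat.v1 或 tf.data）：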

    # -*- coding:utf-8 -*-
    import tensorflow as tf
    
    def read_data(file_queue):
        """Decode one Iris CSV row pulled from *file_queue* into (features, label).

        The reader skips the first line of each file (the header row). Each row
        is decoded as: Id (int), four float measurements, Species (string).
        The Id column is decoded but not returned.

        Returns:
            A pair ``(features, label)``:
            ``features`` stacks SepalLengthCm, SepalWidthCm, PetalLengthCm and
            PetalWidthCm into a single float tensor;
            ``label`` is an int32 class index -- 0 for Iris-setosa,
            1 for Iris-versicolor, 2 for Iris-virginica, -1 for any other value.
        """
        line_reader = tf.TextLineReader(skip_header_lines=1)
        _, csv_line = line_reader.read(file_queue)

        # The defaults also fix each column's dtype: int, 4 x float, string.
        column_defaults = [[0], [0.], [0.], [0.], [0.], ['']]
        columns = tf.decode_csv(csv_line, column_defaults)
        measurements, species = columns[1:5], columns[5]

        # The Iris labels are strings, so map each known species name to an
        # integer class id. The names are mutually exclusive, hence exclusive=True.
        # `class_id=class_id` binds the loop value now (avoids late binding).
        class_names = ('Iris-setosa', 'Iris-versicolor', 'Iris-virginica')
        branches = [
            (tf.equal(species, tf.constant(name)),
             lambda class_id=class_id: tf.constant(class_id))
            for class_id, name in enumerate(class_names)
        ]
        label = tf.case(branches, lambda: tf.constant(-1), exclusive=True)

        return tf.stack(measurements), label
    
    def create_pipeline(filename, batch_size, num_epochs=None, min_after_dequeue=1000):
        """Build a shuffled, batched input pipeline over a single CSV file.

        Args:
            filename: path of the CSV file to read (rows decoded by read_data).
            batch_size: number of examples per batch.
            num_epochs: number of passes over the file; None cycles forever.
                When set, the filename producer keeps its epoch counter in a
                local variable, so tf.local_variables_initializer() must be run,
                and reading raises tf.errors.OutOfRangeError once exhausted.
            min_after_dequeue: minimum number of examples left in the shuffle
                buffer after each dequeue. Larger values mix examples better at
                the cost of memory and start-up time.

        Returns:
            (example_batch, label_batch) tensors with a leading batch dimension.
        """
        file_queue = tf.train.string_input_producer([filename], num_epochs=num_epochs)
        example, label = read_data(file_queue)

        # TF guidance: capacity = min_after_dequeue + (num_threads + margin) * batch_size.
        # With one enqueue thread, keep ~3 batches of head-room (not just one)
        # so the reader can refill the buffer while batches are being taken out.
        capacity = min_after_dequeue + 3 * batch_size
        example_batch, label_batch = tf.train.shuffle_batch(
            [example, label],
            batch_size=batch_size,
            capacity=capacity,
            min_after_dequeue=min_after_dequeue,
        )

        return example_batch, label_batch
    
    x_train_batch, y_train_batch = create_pipeline('Iris-train.csv', 50, num_epochs=1000)
    # NOTE(review): the test pipeline is built (and its queue runners are started
    # below) but never read in this script.
    x_test, y_test = create_pipeline('Iris-test.csv', 60)

    init_op = tf.global_variables_initializer()
    # string_input_producer(num_epochs=...) stores its epoch counter in a local
    # variable, so local variables must be initialized as well.
    local_init_op = tf.local_variables_initializer()

    # The `with` statement closes the session on exit; no explicit close() needed.
    with tf.Session() as sess:
        sess.run(init_op)
        sess.run(local_init_op)

        # Start the queue-runner threads that fill the filename and shuffle queues.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        # Pull training batches until the pipeline is exhausted (num_epochs
        # reached -> OutOfRangeError) or a runner thread requests a stop.
        try:
            while not coord.should_stop():
                example, label = sess.run([x_train_batch, y_train_batch])
                print(example)
                print(label)
        except tf.errors.OutOfRangeError:
            print('Done reading')
        finally:
            # Signal all queue-runner threads to stop.
            coord.request_stop()

        # Wait for the runner threads to finish; re-raises any exception a
        # runner thread reported to the coordinator.
        coord.join(threads)

    数据集是鸢尾花（Iris）数据集，请自行下载。下面给出一个示例（文件第一行为表头，代码中通过 skip_header_lines=1 跳过）：

    Id,SepalLengthCm,SepalWidthCm,PetalLengthCm,PetalWidthCm,Species
    21,5.4,3.4,1.7,0.2,Iris-setosa
    22,5.1,3.7,1.5,0.4,Iris-setosa
    23,4.6,3.6,1.0,0.2,Iris-setosa
    24,5.1,3.3,1.7,0.5,Iris-setosa
    25,4.8,3.4,1.9,0.2,Iris-setosa
    26,5.0,3.0,1.6,0.2,Iris-setosa
    27,5.0,3.4,1.6,0.4,Iris-setosa
    28,5.2,3.5,1.5,0.2,Iris-setosa
    29,5.2,3.4,1.4,0.2,Iris-setosa
    30,4.7,3.2,1.6,0.2,Iris-setosa
    31,4.8,3.1,1.6,0.2,Iris-setosa
    32,5.4,3.4,1.5,0.4,Iris-setosa
    33,5.2,4.1,1.5,0.1,Iris-setosa
    34,5.5,4.2,1.4,0.2,Iris-setosa
    35,4.9,3.1,1.5,0.1,Iris-setosa
    36,5.0,3.2,1.2,0.2,Iris-setosa
    37,5.5,3.5,1.3,0.2,Iris-setosa
  • 相关阅读:
    SpringMVC:JSON讲解
    SpringMVC:文件上传和下载
    字符串的使用
    python中的作用域与名称空间
    深、浅copy
    代码块与小数据池之间的关系
    关于敏感字符的筛选替换
    列表的增、删、改、查
    最简三级菜单
    python2.x与python3.x的区别
  • 原文地址:https://www.cnblogs.com/hunttown/p/6844477.html
Copyright © 2011-2022 走看看