zoukankan      html  css  js  c++  java
  • tensorflow batch

    这两天一直在看tensorflow中的读取数据的队列,说实话,真的是很难懂。也可能我之前没这方面的经验吧,最早我都使用的theano,什么都是自己写。经过这两天的文档以及相关资料,并且请教了国内的师弟。今天算是有点小感受了。简单的说,就是计算图是从一个管道中读取数据的,录入管道是用的现成的方法,读取也是。为了保证多线程的时候从一个管道读取数据不会乱吧,所以这种时候 读取的时候需要线程管理的相关操作。今天我实验室了一个简单的操作,就是给一个有序的数据,看看读出来是不是有序的,结果发现是有序的,所以直接给代码:

    import tensorflow as tf
    import numpy as np
    
    def generate_data():
        num = 25
        label = np.asarray(range(0, num))
        images = np.random.random([num, 5, 5, 3])
        print('label size :{}, image size {}'.format(label.shape, images.shape))
        return label, images
    
    def get_batch_data():
        label, images = generate_data()
        images = tf.cast(images, tf.float32)
        label = tf.cast(label, tf.int32)
        input_queue = tf.train.slice_input_producer([images, label], shuffle=False)
        image_batch, label_batch = tf.train.batch(input_queue, batch_size=10, num_threads=1, capacity=64)
        return image_batch, label_batch
    
    image_batch, label_batch = get_batch_data()
    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess, coord)
        i = 0
        try:
            while not coord.should_stop():
                image_batch_v, label_batch_v = sess.run([image_batch, label_batch])
                i += 1
                for j in range(10):
                    print(image_batch_v.shape, label_batch_v[j])
        except tf.errors.OutOfRangeError:
            print("done")
        finally:
            coord.request_stop()
        coord.join(threads)

    记得那个slice_input_producer方法,默认是要shuffle的哈。

    Besides, I would like to comment this code. 
    1: there is a parameter ‘num_epochs’ in slice_input_producer, which controls how many epochs the slice_input_producer method would work. when this method runs the specified epochs, it would report the OutOfRangeRrror. I think it would be useful for our control the training epochs. 
    2: the output of this method is one single image, we could operate this single image with tensorflow API, such as normalization, crops, and so on, then this single image is feed to batch method, a batch of images for training or testing would be received

  • 相关阅读:
    怎么才能学好php
    MySQL: ON DUPLICATE KEY UPDATE 用法 避免重复插入数据
    RabbitMQ挂掉问题处理
    页面出现假死的问题
    memkeys 安装时遇到的问题及解决办法
    php 中的$argv与$argc
    PHPExcell单元格中某些时间格式的内容不能正确获得的处理办法
    php中的后期静态绑定("Late Static Binding")
    mybatis从零阅读(一)大纲
    windows 命令
  • 原文地址:https://www.cnblogs.com/Alex0111/p/8493475.html
Copyright © 2011-2022 走看看