zoukankan      html  css  js  c++  java
  • tf.train.slice_input_producer()

    tf.train.slice_input_producer处理的是来源tensor的数据

    转载自:https://blog.csdn.net/dcrmg/article/details/79776876 里面有详细参数解释

    官方说明

    简单使用

    import tensorflow as tf
     
    images = ['img1', 'img2', 'img3', 'img4', 'img5']
    labels= [1,2,3,4,5]
     
    epoch_num=8
     
    f = tf.train.slice_input_producer([images, labels],num_epochs=None,shuffle=True)
     
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        for i in range(epoch_num):
            k = sess.run(f)
            print (i,k)
     
        coord.request_stop()
        coord.join(threads)

    运行结果:

    用tf.data.Dataset.from_tensor_slices调用,之前的会被抛弃,用法:https://blog.csdn.net/qq_32458499/article/details/78856530

    结合批处理

    import tensorflow as tf
    import numpy as np
     
    # 样本个数
    sample_num=5
    # 设置迭代次数
    epoch_num = 2
    # 设置一个批次中包含样本个数
    batch_size = 3
    # 计算每一轮epoch中含有的batch个数
    batch_total = int(sample_num/batch_size)+1
     
    # 生成4个数据和标签
    def generate_data(sample_num=sample_num):
        labels = np.asarray(range(0, sample_num))
        images = np.random.random([sample_num, 224, 224, 3])
        print('image size {},label size :{}'.format(images.shape, labels.shape))
     
        return images,labels
     
    def get_batch_data(batch_size=batch_size):
        images, label = generate_data()
        # 数据类型转换为tf.float32
        images = tf.cast(images, tf.float32)
        label = tf.cast(label, tf.int32)
     
        #从tensor列表中按顺序或随机抽取一个tensor
        input_queue = tf.train.slice_input_producer([images, label], shuffle=False)
     
        image_batch, label_batch = tf.train.batch(input_queue, batch_size=batch_size, num_threads=1, capacity=64)
        return image_batch, label_batch
     
    image_batch, label_batch = get_batch_data(batch_size=batch_size)
     
    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess, coord)
        try:
            for i in range(epoch_num):  # 每一轮迭代
                print ('************')
                for j in range(batch_total): #每一个batch
                    print ('--------')
                    # 获取每一个batch中batch_size个样本和标签
                    image_batch_v, label_batch_v = sess.run([image_batch, label_batch])
                    # for k in
                    print(image_batch_v.shape, label_batch_v)
        except tf.errors.OutOfRangeError:
            print("done")
        finally:
            coord.request_stop()
        coord.join(threads)

    运行结果:

  • 相关阅读:
    POJ 3261 Milk Patterns (求可重叠的k次最长重复子串)
    UVaLive 5031 Graph and Queries (Treap)
    Uva 11996 Jewel Magic (Splay)
    HYSBZ
    POJ 3580 SuperMemo (Splay 区间更新、翻转、循环右移,插入,删除,查询)
    HDU 1890 Robotic Sort (Splay 区间翻转)
    【转】ACM中java的使用
    HDU 4267 A Simple Problem with Integers (树状数组)
    POJ 1195 Mobile phones (二维树状数组)
    HDU 4417 Super Mario (树状数组/线段树)
  • 原文地址:https://www.cnblogs.com/helloworld0604/p/10044538.html
Copyright © 2011-2022 走看看