  • [Repost] tf.train.slice_input_producer() and tf.train.batch()

    Original article:

    https://www.jianshu.com/p/8ba9cfc738c2

    ------------------------------------------------------------------------------------------------

    1. tf.train.slice_input_producer: a queue-based method for feeding input data to a model.

    tf.train.slice_input_producer(
        tensor_list,
        num_epochs=None, 
        shuffle=True,
        seed=None,
        capacity=32,
        shared_name=None,
        name=None
    )

    Its parameters are:

    Args:
    tensor_list: A list of Tensor objects. Every Tensor in tensor_list must have the same size in the first dimension.
    num_epochs: An integer (optional), the number of passes over the input. If specified, slice_input_producer produces each slice num_epochs times before generating an OutOfRange error. If not specified, slice_input_producer can cycle through the slices an unlimited number of times.
    shuffle: Boolean. If true, the slice indices are randomly shuffled within each epoch.
    seed: An integer (optional). Seed used if shuffle == True.
    capacity: An integer. Sets the queue capacity.
    shared_name: (optional). If set, this queue will be shared under the given name across multiple sessions.
    name: A name for the operations (optional).
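The following pure-Python sketch makes the epoch and shuffle semantics above concrete. It is a toy model under stated assumptions, not TensorFlow's implementation. `SliceProducer`, `dequeue()` and `OutOfRangeError` are hypothetical names, and queues, capacity and threads are left out. The sketch shows two properties the argument list describes. First, every dequeue returns row i of every input, so paired lists stay aligned. Second, with num_epochs set, each slice is produced exactly num_epochs times before an out-of-range error.

```python
import random
from typing import Optional, Sequence, Tuple


class OutOfRangeError(Exception):
    """Hypothetical stand-in for tf.errors.OutOfRangeError in this model."""


class SliceProducer:
    """Toy model of tf.train.slice_input_producer (no queues, no threads)."""

    def __init__(self, tensor_list: Sequence[Sequence], num_epochs: Optional[int] = None,
                 shuffle: bool = True, seed: Optional[int] = None):
        sizes = {len(t) for t in tensor_list}
        if len(sizes) != 1 or 0 in sizes:
            raise ValueError("inputs must be non-empty and share the same first dimension")
        self._lists = tensor_list
        self._n = sizes.pop()
        self._num_epochs = num_epochs
        self._shuffle = shuffle
        self._rng = random.Random(seed)
        self._epoch = 0
        self._order = []

    def dequeue(self) -> Tuple:
        """Return row i of every input, i.e. one aligned slice."""
        if not self._order:
            # Start a new epoch, or stop once num_epochs full passes are done.
            if self._num_epochs is not None and self._epoch >= self._num_epochs:
                raise OutOfRangeError("all epochs consumed")
            self._order = list(range(self._n))
            if self._shuffle:
                self._rng.shuffle(self._order)  # reshuffle the indices every epoch
            self._epoch += 1
        i = self._order.pop(0)
        return tuple(t[i] for t in self._lists)


inputs = ["input/a.jpg", "input/b.jpg", "input/c.jpg"]
targets = ["output/a.jpg", "output/b.jpg", "output/c.jpg"]
producer = SliceProducer([inputs, targets], num_epochs=2, shuffle=True, seed=123)

pairs = []
try:
    while True:
        pairs.append(producer.dequeue())
except OutOfRangeError:
    pass  # end of input after num_epochs passes
print(len(pairs))  # 6
```

Only the indices are shuffled, so an input file name is always returned together with its matching target. The image example relies on exactly this property.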

    Example usage:

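        import os
        import tensorflow as tf  # TensorFlow 1.x queue API (tf.compat.v1 in TF 2.x)

        # Build the lists of input and target image paths (flist holds the shared file names)
        input_files = [os.path.join(dirname, 'input', f) for f in flist]
        output_files = [os.path.join(dirname, 'output', f) for f in flist]

        # The Python lists are converted internally into constant string Tensors
        # and their slices are enqueued. The original seed=0123 is an octal literal
        # in Python 2 (= 83) and a SyntaxError in Python 3, so a decimal seed is used.
        input_queue, output_queue = tf.train.slice_input_producer(
            [input_files, output_files], shuffle=self.shuffle,
            seed=123, num_epochs=self.num_epochs)

        # Each step, slice_input_producer() yields one (input, target) pair of
        # file names, which is handed to the ReadFile op
        input_file = tf.read_file(input_queue)
        output_file = tf.read_file(output_queue)

        # Decode into RGB image tensors
        im_input = tf.image.decode_jpeg(input_file, channels=3)
        im_output = tf.image.decode_jpeg(output_file, channels=3)

        # With num_epochs set, run tf.local_variables_initializer() (it creates a
        # local epoch counter) and tf.train.start_queue_runners() before evaluating.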

    2. tf.train.batch()

    tf.train.batch(
        tensors,
        batch_size,
        num_threads=1,
        capacity=32,
        enqueue_many=False,
        shapes=None,
        dynamic_pad=False,
        allow_smaller_final_batch=False,
        shared_name=None,
        name=None
    )

    Its parameters are:

    Args:
    tensors: The list or dictionary of tensors to enqueue.
    batch_size: The new batch size pulled from the queue.
    num_threads: The number of threads enqueuing tensors. The batching will be nondeterministic if num_threads > 1.
    capacity: An integer. The maximum number of elements in the queue.
    enqueue_many: Whether each tensor in tensors is a single example. If False (the default), each tensor is one example; if True, the first dimension of each tensor indexes a batch of examples.
    shapes: (Optional) The shapes for each example. Defaults to the inferred shapes for tensors.
    dynamic_pad: Boolean. Allow variable dimensions in input shapes. The given dimensions are padded upon dequeue so that tensors within a batch have the same shapes.
    allow_smaller_final_batch: (Optional) Boolean. If True, allow the final batch to be smaller if there are insufficient items left in the queue.
    shared_name: (Optional). If set, this queue will be shared under the given name across multiple sessions.
    name: (Optional) A name for the operations.
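A similarly hedged pure-Python model of the batching step follows. `batch()` is a hypothetical helper, and examples are lists of ints standing in for 1-D tensors. The model omits `num_threads`, `capacity`, `enqueue_many` and `shapes`. It shows how `allow_smaller_final_batch` decides whether leftover examples are emitted or dropped. It also shows how `dynamic_pad` pads each example to the longest one in its own batch. The zero padding is an assumption made here for numeric data.

```python
from typing import List, Sequence


def batch(examples: Sequence[Sequence[int]], batch_size: int,
          dynamic_pad: bool = False,
          allow_smaller_final_batch: bool = False) -> List[List[List[int]]]:
    """Toy model of tf.train.batch over a finite list of 1-D examples."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    if not dynamic_pad and len({len(e) for e in examples}) > 1:
        # Without dynamic_pad every example must have the same fixed shape.
        raise ValueError("variable-length examples need dynamic_pad=True")
    batches = []
    for start in range(0, len(examples), batch_size):
        group = [list(e) for e in examples[start:start + batch_size]]
        if len(group) < batch_size and not allow_smaller_final_batch:
            break  # too few examples left: drop them, as when the input runs out
        longest = max(len(e) for e in group)
        # Zero-pad each example up to the longest one in this batch only
        # (a no-op when all examples already share one length).
        batches.append([e + [0] * (longest - len(e)) for e in group])
    return batches


examples = [[1], [2, 2], [3, 3, 3], [4], [5, 5]]
print(batch(examples, batch_size=2, dynamic_pad=True))
# [[[1, 0], [2, 2]], [[3, 3, 3], [4, 0, 0]]]
print(batch(examples, batch_size=2, dynamic_pad=True, allow_smaller_final_batch=True))
# [[[1, 0], [2, 2]], [[3, 3, 3], [4, 0, 0]], [[5, 5]]]
```

Each batch is padded independently, so the second batch is wider than the first. This matches the documented rule that padding only makes tensors within one batch the same shape.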

    Example usage:

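    # sample: the list of tensors that make up one example. tf.train.batch needs
    # their shapes to be inferable unless shapes= is passed or dynamic_pad=True.
    samples = tf.train.batch(
        sample,
        batch_size=self.batch_size,
        num_threads=self.nthreads,
        capacity=self.capacity)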
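With a finite num_epochs, the two functions together produce a predictable number of batches. Assume N examples, num_epochs = E and batch_size = B. The producer then emits N·E slices before its out-of-range error. The batch op either drops the remainder or emits it as a smaller final batch:

```latex
\text{number of batches} =
\begin{cases}
\left\lfloor \dfrac{N E}{B} \right\rfloor & \text{if allow\_smaller\_final\_batch = False} \\[6pt]
\left\lceil \dfrac{N E}{B} \right\rceil & \text{if allow\_smaller\_final\_batch = True}
\end{cases}
```

For example, N = 10, E = 1, B = 4 gives ⌊10/4⌋ = 2 batches, with 2 examples dropped. With the flag set, it gives ⌈10/4⌉ = 3 batches, the last holding 2 examples. Batches are not aligned to epoch boundaries, because the producer streams slices from consecutive epochs without a break.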
  • Source: https://www.cnblogs.com/devilmaycry812839668/p/10961041.html