zoukankan      html  css  js  c++  java
  • Tensorflow池化

    # -*- encoding: utf-8 -*-
    import tensorflow as tf
    
    # 定义一张4单通道*4图片
    # data = tf.random.truncated_normal(shape=(1, 1, 4, 4))
    data = tf.constant(
        [[[[1, 2, 3, 4],
            [5, 6, 7, 8],
            [9, 10, 11, 12],
            [13, 14, 15, 16]]]],
        dtype="float32"  # avg_pool 要求都是 float32 类型
    )
    
    # reshape 成 batch_size, height, width, n_channels ,因为这是 max_pool函数要求的格式
    # batch_size=1,因为就一张图片, 高和宽都是4,通道是1
    data = tf.reshape(data, [1, 4, 4, 1])
    
    # pool_size 设置成 1,4,1,1; 窗口是[1,1,1,1]
    # 1,4,1,1的数组举个例子:
    # [
    #     [[1]],
    #     [[1]],
    #     [[1]],
    #     [[1]],
    # ]
    # 第一个数字要与batch_size保持一致,后面的shape定义了一个扫描块,也就是纵向按列扫描
    # strides=[1,1,1,1] 每次移动一个单位, 最后输出应是: [1,1,4,1]
    
    output1 = tf.nn.max_pool(data, [1, 4, 1, 1], [1, 1, 1, 1], padding='VALID')
    print(output1)
    
    # tf.Tensor(
    #     [[[[13]
    #        [14]
    #        [15]
    #        [16]]]], shape=(1, 1, 4, 1), dtype=int32)
    
    output2 = tf.nn.avg_pool(data, [1, 4, 1, 1], [1, 1, 1, 1], padding='VALID')
    print(output2)
    
    #
    # tf.Tensor(
    #     [[[[ 7.]
    #        [ 8.]
    #        [ 9.]
    #        [10.]]]], shape=(1, 1, 4, 1), dtype=float32)
    
    

    注意事项:

    1. 图片的通道,描述图片用RGB3种颜色,每个颜色都需要一个二维矩阵,成为一个通道
    2. avg_pool 需要输入的数据类型为float, 否则报错:tensorflow.python.framework.errors_impl.NotFoundError: Could not find valid device for node.
    3. 输入数据的格式需要为一个4维数组,shape=(batch_size, height, width, n_channels ) ,这个格式专门为图片设定的,其他类型要自己转换

    池化说明:

    TF提供了tf.keras.layers.AvgPool2D,tf.keras.layers.MaxPool2D 来搭建池化层。

  • 相关阅读:
    第六章实验报告
    第三次实验报告
    循环结构课后反思
    分支结构试验
    第七组509寝室课后习题4.34
    c语言实验报告
    第九章 结构体与共用体
    第八章实验报告(指针)
    第7章 数组实验报告
    函数与宏定义实验报告(2)
  • 原文地址:https://www.cnblogs.com/oaks/p/14028954.html
Copyright © 2011-2022 走看看