zoukankan      html  css  js  c++  java
  • Tensorflow--池化操作的梯度

    Tensorflow–池化操作的梯度

    池化操作的梯度分两部分介绍,第一部分介绍平均值池化的梯度计算,第二部分介绍最大值池化的梯度计算

    一.平均值池化的梯度

    利用计算梯度的函数gradients实现上述示例,具体代码如下:

    import tensorflow as tf
    import numpy as np
    
    # x是1个3行3列1深度的张量
    x=tf.placeholder(tf.float32,(1,3,3,1))
    
    # 2x2的掩码,步长是(1,1,1,1)的valid平均值池化操作
    sigma=tf.nn.avg_pool(x,(1,2,2,1),(1,1,1,1),'VALID')
    
    # 构造一个函数F:池化结果的和
    F=tf.reduce_sum(sigma)
    
    session=tf.Session()
    
    xvalue=np.random.randn(1,3,3,1)
    grad=tf.gradients(F,[sigma,x])
    results=session.run(grad,{x:xvalue})
    
    print("---针对sigma的梯度---:")
    print(results[0])
    print("---针对x的梯度---:")
    print(results[1])
    
    ---针对sigma的梯度---:
    [[[[1.]
       [1.]]
    
      [[1.]
       [1.]]]]
    ---针对x的梯度---:
    [[[[0.25]
       [0.5 ]
       [0.25]]
    
      [[0.5 ]
       [1.  ]
       [0.5 ]]
    
      [[0.25]
       [0.5 ]
       [0.25]]]]
    

    二.最大值池化的梯度

    import tensorflow as tf
    
    # 初始化x的值
    x=tf.Variable(tf.constant([
                               [
                               [[8],[2],[9],[3]],
                               [[4],[6],[7],[10]],
                               [[20],[13],[1],[5]],
                               [[12],[18],[19],[14]]
                               ]
                               ],tf.float32),dtype=tf.float32)
    
    # 2x2的掩码,步长为2x2的最大值池化操作
    x_maxPool=tf.nn.max_pool(x,(1,2,2,1),(1,2,2,1),'VALID')
    
    # 对以上最大值池化结果计算其平方和
    F=tf.reduce_sum(tf.square(x_maxPool))
    
    session=tf.Session()
    session.run(tf.global_variables_initializer())
    
    opti=tf.train.GradientDescentOptimizer(0.5).minimize(F)
    
    # 打印前2次结果
    for i in range(2):
        session.run(opti)
        print(session.run(x))
    
    [[[[ 0.]
       [ 2.]
       [ 9.]
       [ 3.]]
    
      [[ 4.]
       [ 6.]
       [ 7.]
       [ 0.]]
    
      [[ 0.]
       [13.]
       [ 1.]
       [ 5.]]
    
      [[12.]
       [18.]
       [ 0.]
       [14.]]]]
    [[[[ 0.]
       [ 2.]
       [ 0.]
       [ 3.]]
    
      [[ 4.]
       [ 0.]
       [ 7.]
       [ 0.]]
    
      [[ 0.]
       [13.]
       [ 1.]
       [ 5.]]
    
      [[12.]
       [ 0.]
       [ 0.]
       [ 0.]]]]
    
  • 相关阅读:
    Shell编程-05-Shell中条件测试与比较
    Shell编程-04-Shell中变量数值计算
    Shell编程-03-Shell中的特殊变量和扩展变量
    Shell编程-02-Shell变量
    Shell编程-01-Shell脚本初步入门
    Windows与Linux相互远程桌面连接
    awk基础05-自定义函数和脚本
    使用Kafka Connect创建测试数据生成器
    设置KAFKA
    Apache Kafka使用默认配置执行一些负载测试来完成性能测试和基准测试
  • 原文地址:https://www.cnblogs.com/LQ6H/p/10343263.html
Copyright © 2011-2022 走看看