zoukankan      html  css  js  c++  java
  • Tensorflow--池化操作的梯度

    Tensorflow–池化操作的梯度

    池化操作的梯度分两部分介绍,第一部分介绍平均值池化的梯度计算,第二部分介绍最大值池化的梯度计算

    一.平均值池化的梯度

    利用计算梯度的函数gradients实现上述示例,具体代码如下:

    import tensorflow as tf
    import numpy as np
    
    # x是1个3行3列1深度的张量
    x=tf.placeholder(tf.float32,(1,3,3,1))
    
    # 2x2的掩码,步长是(1,1,1,1)的valid平均值池化操作
    sigma=tf.nn.avg_pool(x,(1,2,2,1),(1,1,1,1),'VALID')
    
    # 构造一个函数F:池化结果的和
    F=tf.reduce_sum(sigma)
    
    session=tf.Session()
    
    xvalue=np.random.randn(1,3,3,1)
    grad=tf.gradients(F,[sigma,x])
    results=session.run(grad,{x:xvalue})
    
    print("---针对sigma的梯度---:")
    print(results[0])
    print("---针对x的梯度---:")
    print(results[1])
    
    ---针对sigma的梯度---:
    [[[[1.]
       [1.]]
    
      [[1.]
       [1.]]]]
    ---针对x的梯度---:
    [[[[0.25]
       [0.5 ]
       [0.25]]
    
      [[0.5 ]
       [1.  ]
       [0.5 ]]
    
      [[0.25]
       [0.5 ]
       [0.25]]]]
    

    二.最大值池化的梯度

    import tensorflow as tf
    
    # 初始化x的值
    x=tf.Variable(tf.constant([
                               [
                               [[8],[2],[9],[3]],
                               [[4],[6],[7],[10]],
                               [[20],[13],[1],[5]],
                               [[12],[18],[19],[14]]
                               ]
                               ],tf.float32),dtype=tf.float32)
    
    # 2x2的掩码,步长为2x2的最大值池化操作
    x_maxPool=tf.nn.max_pool(x,(1,2,2,1),(1,2,2,1),'VALID')
    
    # 对以上最大值池化结果计算其平方和
    F=tf.reduce_sum(tf.square(x_maxPool))
    
    session=tf.Session()
    session.run(tf.global_variables_initializer())
    
    opti=tf.train.GradientDescentOptimizer(0.5).minimize(F)
    
    # 打印前2次结果
    for i in range(2):
        session.run(opti)
        print(session.run(x))
    
    [[[[ 0.]
       [ 2.]
       [ 9.]
       [ 3.]]
    
      [[ 4.]
       [ 6.]
       [ 7.]
       [ 0.]]
    
      [[ 0.]
       [13.]
       [ 1.]
       [ 5.]]
    
      [[12.]
       [18.]
       [ 0.]
       [14.]]]]
    [[[[ 0.]
       [ 2.]
       [ 0.]
       [ 3.]]
    
      [[ 4.]
       [ 0.]
       [ 7.]
       [ 0.]]
    
      [[ 0.]
       [13.]
       [ 1.]
       [ 5.]]
    
      [[12.]
       [ 0.]
       [ 0.]
       [ 0.]]]]
    
  • 相关阅读:
    云时代架构阅读笔记十一——分布式架构中数据一致性常见的几个问题
    云时代架构阅读笔记十——支付宝架构师眼中的高并发架构
    云时代架构阅读笔记九——Disruptor无锁框架为啥这么快
    云时代架构阅读笔记八——JVM性能调优
    lightoj 1024 (高精度乘单精度)
    lightoj 1023
    lightoj 1022
    codeforces 260 div2 C题
    codeforces 260 div2 B题
    codedorces 260 div2 A题
  • 原文地址:https://www.cnblogs.com/LQ6H/p/10343263.html
Copyright © 2011-2022 走看看