zoukankan      html  css  js  c++  java
  • tensorflow中的batch_normalization实现

      tensorflow中实现batch_normalization的函数主要有两个:

        1)tf.nn.moments

        2)tf.nn.batch_normalization

      tf.nn.moments主要是用来计算均值mean和方差variance的值,这两个值被用在之后的tf.nn.batch_normalization中

      tf.nn.moments(x, axis,...)

      主要有两个参数:输入的batchs数据;进行求均值和方差的维度axis,axis的值是一个列表,可以传入多个维度

      返回值:mean和variance

      tf.nn.batch_normalization(x, mean, variance, offset, scala, variance_epsilon)

      主要参数:输入的batchs数据;mean;variance;offset和scala,这两个参数是要学习的参数,所以只要给出初始值,一般offset=0,scala=1;variance_epsilon是为了保证variance为0时,除法仍然可行,设置为一个较小的值即可

      输出:bn处理后的数据

      具体代码如下:    

    import tensorflow as tf
    import numpy as np
    
    
    X = tf.constant(np.random.uniform(1, 10, size=(3, 3)), dtype=tf.float32)
    axis = list(range(len(X.get_shape()) - 1))
    mean, variance = tf.nn.moments(X, axis)
    print(axis)
    
    X_batch = tf.nn.batch_normalization(X, mean, variance, 0, 1, 0.001)
    
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        mean, variance, X_batch = sess.run([mean, variance, X_batch])
        print(mean)
        print(variance)
        print(X_batch)
    
    输出:

    axis: [0]
    mean: [5.124098 3.0998185 4.723417 ]
    variance: [3.7908943 1.7062012 3.8243492]
    X_batch: [[-0.32879925 -1.3645337 0.39226937]
          [-1.0266179 0.36186576 -1.3726556 ]
          [ 1.355417 1.0026684 0.98038626]]

     
  • 相关阅读:
    关于size_t
    图的搜索算法之迷宫问题和棋盘马走日问题
    螺旋矩阵与螺旋队列
    内存分配问题
    质数的判断
    全局变量、静态全局变量、静态局部变量和局部变量的区别
    程序员必知之代码规范标准
    字符串查找与类型转换(C/C++)
    sizeof与strlen()的用法与区别
    关于C++的输入输出流(cin、sstream和cout)
  • 原文地址:https://www.cnblogs.com/jiangxinyang/p/9394353.html
Copyright © 2011-2022 走看看