zoukankan      html  css  js  c++  java
  • tensorflow(三十八):Batch Normalization

    一、不进行归一化,某些W变化对loss影响较大

     二、进行归一化

    1、可以看到,Batch Norm结束后,只得到三个数值,每个通道一个。

     2、正常的Batch Norm过后,均值为0,方差为1,但是需要再加一个贝塔和伽马。(B,r)需要学出来。

     

     变成了均值为B,方差为r。

    三、用法

    1、下面的center是均值B,scale是方差r。最后一个参数用于测试时候。

     

     

     

     

     

    import tensorflow as tf
    
    from tensorflow import keras
    from tensorflow.keras import layers, optimizers
    
    
    # 2 images with 4x4 size, 3 channels
    # we explicitly enforce the mean and stddev to N(1, 0.5)
    x = tf.random.normal([2,4,4,3], mean=1.,stddev=0.5)
    
    net = layers.BatchNormalization(axis=-1, center=True, scale=True,
                                    trainable=True)
    
    out = net(x)
    print('forward in test mode:', net.variables)
    
    
    out = net(x, training=True)
    print('forward in train mode(1 step):', net.variables)
    
    for i in range(100):
        out = net(x, training=True)
    print('forward in train mode(100 steps):', net.variables)
    
    
    optimizer = optimizers.SGD(lr=1e-2)
    for i in range(10):
        with tf.GradientTape() as tape:
            out = net(x, training=True)
            loss = tf.reduce_mean(tf.pow(out,2)) - 1
    
        grads = tape.gradient(loss, net.trainable_variables)
        optimizer.apply_gradients(zip(grads, net.trainable_variables))
    print('backward(10 steps):', net.variables)
  • 相关阅读:
    oracle中常用的函数
    请求转发和URL重定向的原理和区别
    servlet的生命周期和servlet的继承关系
    Jdbc来操作事物 完成模拟银行的转账业务
    Map的嵌套 练习
    正则表达式练习
    学习 day4 html 盒子模型
    学习day03
    学习day02
    学习day01
  • 原文地址:https://www.cnblogs.com/zhangxianrong/p/14725679.html
Copyright © 2011-2022 走看看