  • TensorFlow (16): tensor clipping (tf.clip_by_value(), tf.clip_by_norm(), tf.clip_by_global_norm())

    1. tf.clip_by_value(): clip by value
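    tf.clip_by_value(t, clip_value_min, clip_value_max) clamps every element of t into the range [clip_value_min, clip_value_max]: anything below the minimum is raised to the minimum, anything above the maximum is lowered to the maximum. A minimal sketch (the tensor values and bounds here are just for illustration):

    import tensorflow as tf

    t = tf.constant([-2., -1., 0., 3., 8.])
    # elements below 0 are raised to 0, elements above 5 are lowered to 5
    print(tf.clip_by_value(t, clip_value_min=0., clip_value_max=5.).numpy())
    # -> [0. 0. 0. 3. 5.]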

     

    2. tf.clip_by_norm(): clip by norm

    • Proportional scaling: only the magnitude (norm) of the tensor changes, never its direction!
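    Concretely, if the L2 norm of t exceeds clip_norm, tf.clip_by_norm returns t * clip_norm / ||t||; otherwise t is returned unchanged. A minimal sketch (the random tensor is just for illustration):

    import tensorflow as tf

    g = tf.random.normal([2, 2], mean=10.)        # a tensor with a large L2 norm (about 20)
    g2 = tf.clip_by_norm(g, 15)                   # rescaled so that ||g2|| <= 15
    print(float(tf.norm(g)), float(tf.norm(g2)))  # e.g. 20.1 15.0; direction unchanged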

    3. tf.clip_by_global_norm(): scale all gradients by one common factor

    Gradient exploding: the gradients are too large, so each update step is too long and the loss keeps oscillating back and forth instead of converging.
    Gradient vanishing: the gradients are too small, so each step barely changes the weights and the loss stays flat for a long time.
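    Unlike tf.clip_by_norm, which rescales each tensor independently, tf.clip_by_global_norm treats a whole list of gradients as one long vector: it computes a single global norm over all of them and, if that norm exceeds clip_norm, multiplies every tensor by the same factor, so the relative magnitudes between layers are preserved. A minimal sketch (the shapes are borrowed from the network below):

    import tensorflow as tf

    grads = [tf.random.normal([784, 512]), tf.random.normal([512])]
    clipped, global_norm = tf.clip_by_global_norm(grads, clip_norm=15)
    print(float(global_norm))                     # joint norm before clipping
    print(float(tf.linalg.global_norm(clipped)))  # <= 15 after clipping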

    4. Hands-on example

    The script below trains a three-layer MLP on MNIST. The inputs are deliberately scaled into [0, ~5] instead of the usual [0, 1] to provoke exploding gradients, which are then tamed with tf.clip_by_global_norm.

    import tensorflow as tf
    from tensorflow import keras
    from tensorflow.keras import datasets, layers, optimizers
    import os
    
    os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
    print(tf.__version__)
    
    (x, y), _ = datasets.mnist.load_data()
    x = tf.convert_to_tensor(x, dtype=tf.float32) / 50.   # scale x into [0, ~5.1] instead of [0, 1]; this larger range can trigger gradient exploding
    y = tf.convert_to_tensor(y)
    y = tf.one_hot(y, depth=10)
    print("x: ", x.shape, "y: ", y.shape)
    
    train_db = tf.data.Dataset.from_tensor_slices((x,y)).batch(128).repeat(30)
    
    iteration = iter(train_db)
    x, y = next(iteration)
    print("sample: ", x.shape, y.shape)
    
    
    def main():
    
        # 784 => 512
        w1, b1 = tf.Variable(tf.random.truncated_normal([784, 512], stddev=0.1)), tf.Variable(tf.zeros([512]))
        # 512 => 256
        w2, b2 = tf.Variable(tf.random.truncated_normal([512, 256], stddev=0.1)), tf.Variable(tf.zeros([256]))
        # 256 => 10
        w3, b3 = tf.Variable(tf.random.truncated_normal([256, 10], stddev=0.1)), tf.Variable(tf.zeros([10]))
    
        optimizer = optimizers.SGD(learning_rate=0.01)
    
        for step, (x, y) in enumerate(train_db):
            # [batch_size, 28, 28]=> [batch_size, 784]
            x = tf.reshape(x, (-1, 784))
    
            with tf.GradientTape() as tape:
                # layer 1
                h1 = tf.nn.relu(x @ w1 + b1)
                # layer 2
                h2 = tf.nn.relu(h1 @ w2 + b2)
                # layer 3
                out = tf.nn.relu(h2 @ w3 + b3)

                # compute the MSE loss
                # [batch_size, 10] - [batch_size, 10]
                loss = tf.square(out - y)
                # [batch_size, 10] => [batch_size]
                loss = tf.reduce_sum(loss, axis=1)
                # [batch_size] => scalar
                loss = tf.reduce_mean(loss)
    
            # compute gradients
            grads = tape.gradient(loss, [w1, b1, w2, b2, w3, b3])
            print('==before global_norm==')
            for g in grads:
                print(tf.norm(g))  # L2 norm (magnitude) of each gradient; direction is ignored

            grads, _ = tf.clip_by_global_norm(grads, 15)  # rescale all gradients by one common factor so the global norm stays within 15
    
            # print('==after global_norm==')
            # for g in grads:
            #     print(tf.norm(g))  # per-tensor norms after clipping
    
            # update w' = w - lr*grad
            optimizer.apply_gradients(zip(grads, [w1, b1, w2, b2, w3, b3]))
            
            if step % 100 == 0:
                print(step, 'loss: ', float(loss))
    
    if __name__ == '__main__':
        main()

    Without that tf.clip_by_global_norm line, the printed gradient norms grow step by step and training soon diverges, with the loss blowing up to nan.
