zoukankan      html  css  js  c++  java
  • tensorflow prelu的实现细节

    tensorflow prelu的实现细节

    output = tf.nn.leaky_relu(input, alpha=tf_gamma_data,name=name)

    #tf.nn.leaky_relu 的实现假定 alpha(此处为 tf_gamma_data)在 [0, 1] 的范围内

    内部实现方法是 output = tf.maximum(alpha * input, input)

    当 alpha > 1 时,对正值有 alpha * input > input,输出变成 正值 * alpha;对负值有 input > alpha * input,输出保持不变——与 PReLU 期望的结果正好相反

    import numpy as np
    import tensorflow as tf
    
    # Demo: tf.nn.leaky_relu is not a drop-in PReLU once alpha exceeds 1.
    # With the data below, the printed result keeps the negative row as-is and
    # multiplies the positive row by alpha, i.e. the op behaves like
    # max(alpha * x, x). Uses the TensorFlow 1.x graph/session API.

    # Fixed sample: one all-negative row and one all-positive row.
    # (A real activation dump could be loaded with np.loadtxt('tfbn.txt').)
    bn = np.array([[-0.9, -0.9, -0.9], [1.1, 1.1, 1.1]])
    print("srcdata ", bn)
    gamma_data = np.array([1.205321])  # slope deliberately larger than 1
    print("gamma_data ", gamma_data)
    tf_gamma_data = tf.Variable(gamma_data, dtype=np.float32)
    input_data = tf.Variable(bn, dtype=np.float32)
    tf_prelu_test = tf.nn.leaky_relu(input_data, alpha=tf_gamma_data, name=None)
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        tf_prelu_test = sess.run(tf_prelu_test)
        # Explicit "\n" escape: a plain string literal cannot span two
        # physical lines, so the line break has to be written as an escape.
        print("tf_prelu_test: \n", tf_prelu_test)

    srcdata [[-0.9 -0.9 -0.9]
    [ 1.1 1.1 1.1]]
    gamma_data [1.205321]
    tf_prelu_test:
    [[-0.9 -0.9 -0.9 ]
    [ 1.3258531 1.3258531 1.3258531]]
    [Finished in 2.5s]

    使用relu来代替
    output = tf.nn.relu(data) + tf.multiply(alpha, -tf.nn.relu(-data))

    import numpy as np
    import tensorflow as tf
    
    # Demo: PReLU built from two ReLU calls, valid for any alpha (also > 1).
    # relu(x) keeps the positive part unchanged; -relu(-x) isolates the
    # negative part, which is then scaled by alpha before the two are summed.
    # Uses the TensorFlow 1.x graph/session API.

    # Fixed sample: one all-negative row and one all-positive row.
    # (A real activation dump could be loaded with np.loadtxt('tfbn.txt').)
    bn = np.array([[-0.9, -0.9, -0.9], [1.1, 1.1, 1.1]])
    print("srcdata ", bn)
    gamma_data = np.array([1.205321])  # slope deliberately larger than 1
    print("gamma_data ", gamma_data)
    tf_gamma_data = tf.Variable(gamma_data, dtype=np.float32)
    input_data = tf.Variable(bn, dtype=np.float32)
    positive_part = tf.nn.relu(input_data)
    negative_part = -tf.nn.relu(-input_data)
    tf_prelu_test = positive_part + tf.multiply(tf_gamma_data, negative_part)
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        tf_prelu_test = sess.run(tf_prelu_test)
        # Explicit "\n" escape: a plain string literal cannot span two
        # physical lines, so the line break has to be written as an escape.
        print("tf_prelu_test: \n", tf_prelu_test)

    srcdata [[-0.9 -0.9 -0.9]
    [ 1.1 1.1 1.1]]
    gamma_data [1.205321]
    tf_prelu_test:
    [[-1.0847888 -1.0847888 -1.0847888]
    [ 1.1 1.1 1.1 ]]
    [Finished in 2.7s]

  • 相关阅读:
    AUC ROC PR曲线
    L1,L2范数和正则化 到lasso ridge regression
    目标函数和损失函数
    logistic回归和线性回归
    [转]如何处理不均衡数据?
    将Maven项目打包成可执行 jar文件(引用第三方jar)
    Postgresql VACUUM COPY等
    linux安装xgboost
    java社区推荐
    rabbitmq-java api
  • 原文地址:https://www.cnblogs.com/adong7639/p/9224960.html
Copyright © 2011-2022 走看看