zoukankan      html  css  js  c++  java
  • tensorflow prelu的实现细节

    tensorflow prelu的实现细节

    output = tf.nn.leaky_relu(input, alpha=tf_gamma_data,name=name)

    #tf.nn.leaky_relu 限制了tf_gamma_data在[0 1]的范围内 

    内部实现方法是 output = tf.maxmum(alpha * input, input)

    alpha > 1 时,会出现,正值*alpha, 负值不变

    import numpy as np
    import tensorflow as tf
    
    #bn = np.loadtxt('tfbn.txt')
    bn = np.array([[-0.9, -0.9 ,-0.9],[1.1,1.1,1.1]])
    print("srcdata ", bn)
    gamma_data = np.array([1.205321])
    print("gamma_data ", gamma_data)
    tf_gamma_data = tf.Variable(gamma_data, dtype=np.float32)
    input_data = tf.Variable(bn, dtype=np.float32)
    tf_prelu_test = tf.nn.leaky_relu(input_data, alpha=tf_gamma_data,name=None)
    #tf_prelu_test = tf.nn.relu(input_data) + tf.multiply(tf_gamma_data, -tf.nn.relu(-input_data))
    #tf_prelu_test = tf.nn.relu(input_data,name=None)
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        tf_prelu_test = sess.run(tf_prelu_test)
        print("tf_prelu_test: 
    ", tf_prelu_test)

    srcdata [[-0.9 -0.9 -0.9]
    [ 1.1 1.1 1.1]]
    gamma_data [1.205321]
    tf_prelu_test:
    [[-0.9 -0.9 -0.9 ]
    [ 1.3258531 1.3258531 1.3258531]]
    [Finished in 2.5s]

    使用relu来代替
    output = tf.nn.relu(data) + tf.multiply(alpha, -tf.nn.relu(-data))

    import numpy as np
    import tensorflow as tf
    
    #bn = np.loadtxt('tfbn.txt')
    bn = np.array([[-0.9, -0.9 ,-0.9],[1.1,1.1,1.1]])
    print("srcdata ", bn)
    gamma_data = np.array([1.205321])
    print("gamma_data ", gamma_data)
    tf_gamma_data = tf.Variable(gamma_data, dtype=np.float32)
    input_data = tf.Variable(bn, dtype=np.float32)
    #tf_prelu_test = tf.nn.leaky_relu(input_data, alpha=tf_gamma_data,name=None)
    tf_prelu_test = tf.nn.relu(input_data) + tf.multiply(tf_gamma_data, -tf.nn.relu(-input_data))
    #tf_prelu_test = tf.nn.relu(input_data,name=None)
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        tf_prelu_test = sess.run(tf_prelu_test)
        print("tf_prelu_test: 
    ", tf_prelu_test)

    srcdata [[-0.9 -0.9 -0.9]
    [ 1.1 1.1 1.1]]
    gamma_data [1.205321]
    tf_prelu_test:
    [[-1.0847888 -1.0847888 -1.0847888]
    [ 1.1 1.1 1.1 ]]
    [Finished in 2.7s]

  • 相关阅读:
    利用jmeter进行数据库测试
    oracle创建/删除表空间、创建/删除用户并赋予权限
    在linux环境下安装JDK并配置环境变量
    本地与在线图片转Base64及图片预览
    html标签页图标
    Eclipse启动时卡死解决方法
    Java创建目录 mkdir与mkdirs的区别
    Java 获取距离最近一段时间的时间点
    data URI
    JavaScript input file上传前获取文件名、文件类型、文件大小等信息
  • 原文地址:https://www.cnblogs.com/adong7639/p/9224960.html
Copyright © 2011-2022 走看看