zoukankan      html  css  js  c++  java
  • tf.Variable()、tf.get_variable()和tf.placeholder()

    1.tf.Variable()

    tf.Variable(initializer,name)

    功能:tf.Variable()创建变量时,name属性值允许重复,检查到相同名字的变量时,由自动别名机制创建不同的变量。

    参数:

    • initializer:初始化参数;
    • name:可自定义的变量名称

    举例:

    import tensorflow as tf
    v1=tf.Variable(tf.random_normal(shape=[2,3],mean=0,stddev=1),name='v1')
    v2=tf.Variable(tf.constant(2),name='v2')
    v3=tf.Variable(tf.ones([2,3]),name='v3')
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run(v1))
        print(sess.run(v2))
        print(sess.run(v3))
    

    结果如下:

    2.tf.get_variable()

    tf.get_variable(
        name,
        shape=None,
        dtype=None,
        initializer=None,
        regularizer=None,
        trainable=None,
        collections=None,
        caching_device=None,
        partitioner=None,
        validate_shape=True,
        use_resource=None,
        custom_getter=None,
        constraint=None,
        synchronization=tf.VariableSynchronization.AUTO,
        aggregation=tf.VariableAggregation.NONE
    )

    功能:tf.get_variable创建变量时,会进行变量检查,当设置为共享变量时(通过scope.reuse_variables()或tf.get_variable_scope().reuse_variables()),检查到第二个拥有相同名字的变量,就返回已创建的相同的变量;如果没有设置共享变量,则会报[ValueError: Variable varx alreadly exists, disallowed.]的错误。

    参数:

    • name:新变量或现有变量的名称
    • shape:新变量或现有变量的形状
    • dtype:新变量或现有变量的类型(默认为DT_FLOAT)。
    • initializer:变量初始化的方式

    初始化方式:

    • tf.constant_initializer:常量初始化函数
    • tf.random_normal_initializer:正态分布
    • tf.truncated_normal_initializer:截取的正态分布
    • tf.random_uniform_initializer:均匀分布
    • tf.zeros_initializer:全部是0
    • tf.ones_initializer:全是1
    • tf.uniform_unit_scaling_initializer:满足均匀分布,但不影响输出数量级的随机值

    举例:

    v1=tf.Variable(tf.random_normal(shape=[2,3],mean=0,stddev=1),name='v1')
    v2=tf.Variable(tf.random_normal(shape=[2,3],mean=0,stddev=1),name='v1')
    v3=tf.Variable(tf.ones([2,3]),name='v3')
    
    a1 = tf.get_variable(name='a1', shape=[2, 3], initializer=tf.random_normal_initializer(mean=0, stddev=1))
    a2 = tf.get_variable(name='a2', shape=[2, 3], initializer=tf.random_normal_initializer(mean=0, stddev=1))
    a3 = tf.get_variable(name='a3', shape=[2, 3], initializer=tf.ones_initializer())
    
    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        print(sess.run(v1))
        print(sess.run(v2))
        print(sess.run(v3))
        print(sess.run(a1))
        print(sess.run(a2))
        print(sess.run(a3))
    

    v1和v2的参数完全相同,创建时候不会报错;a1和a2的参数完全相同,创建时候会报错  

    3.tf.placeholder()

    tf.placeholder(
        dtype,
        shape=None,
        name=None
    )

    功能:在tensorflow中类似于函数参数,运行时必须传入值。

    TensorFlow链接:https://tensorflow.google.cn/api_docs/python/tf/placeholder?hl=en

    参数:

    • dtype:要进给的张量中的元素类型。常用的是tf.float32,tf.float64等数值类型。
    • shape:要进给的张量的形状(可选)。如果未指定形状,则可以提供任何形状的张量。默认是None,就是一维值,也可以是多维,比如[2,3], [None, 3]表示列是3,行不定。
    • name:操作的名称(可选)。

    举例:

    input1 = tf.placeholder(tf.float32)
    input2 = tf.placeholder(tf.float32)
    
    output = tf.multiply(input1, input2)
    
    with tf.Session() as sess:
        print(sess.run(output, feed_dict={input1: [23.], input2: [4.]})) # [92.]
    

      

    参考文献:

    【1】Tensorflow——tf.Variable()、tf.get_variable()和tf.placeholder()

  • 相关阅读:
    opensuse tumbleweed中安装code
    树莓派中将caplock映射为esc键
    记录一次奇怪但是很有意义的程序编译警告
    新树莓派系统安装ROS记录
    程序的深挖
    intle官方手册下载
    slax linux的定制
    angular4 *ngFor获取index
    axios post传参后台无法接收问题
    AMD、CMD、CommonJs和 ES6对比
  • 原文地址:https://www.cnblogs.com/nxf-rabbit75/p/11276356.html
Copyright © 2011-2022 走看看