zoukankan      html  css  js  c++  java
  • tensorflow基础【3】-Session、placeholder、feed、fetch

    tensorflow 的使用逻辑:

    用 Tensor 表示数据;

    用 Variable 维护状态;

    用 Graph 图表示计算任务,图中的节点称为 op(operation);

    用 Session 会话执行 Graph;

    用 feed 和 fetch 为操作输入和输出数据;

    Session

    tensorflow 是静态图,graph 定义好之后,需要在会话中执行;

    Session 相当于一个环境,负责计算张量,执行操作,他将 op 发送给 GPU 或者 CPU 之类的设备上,并提供执行方法,输出 Tensor;

    在 python 中,输出的 Tensor 为 numpy 的 ndarray 类型;

    返回 Tensor 的操作才需要 Session

    如输出元组,无需 session

    d1 = tf.random_uniform((3, 2)).shape        ### Tensor
    print(d1)

    创建会话

    class Session(BaseSession):
        def __init__(self, target='', graph=None, config=None):
            pass

    参数说明:

    • target:会话连接的执行引擎
    • graph:会话加载的数据流图
    • config:会话启动时的配置项

    如果不输入任何参数,代表启动默认图    【后期会在 Graph 中解释】

    两种启动方法

    ### method1
    sess = tf.Session()
    sess.close()            ### 显式关闭 session
    
    ### method2
    with tf.Session() as sess:      ### 自动关闭 session
        sess.run()

    1. Session 对象使用完毕通常需要关闭以释放资源,当然不关闭也可;

    2. sess.run 方法执行 op

    3. 在会话中执行 graph 通常需要 feed 和 fetch 操作

    交互式环境启动 Session

    例如在 IPython 中启动 Session

    sess = tf.InteractiveSession()

    这种方式下通常用 tensor.eval(session) 和 operation.run() 代替 sess.run

    补充

    占位符

    def placeholder(dtype, shape=None, name=None)

    注意

    1. 如果 shape 为 一维,如 label = [1,1,0,0],要这样写

    label = tf.placeholder(tf.float32, shape=[batch, ])

    2. 如果 shape 第一个元素为 None,可能出现如下错误 ValueError: Cannot convert an unknown Dimension to a Tensor: ?;

    此时可以把 None 改为固定值

    batch = 10
    x = tf.placeholder(tf.float32, shape=(batch, 1))

    3. 定义了 placeholder 后,不要在用 这个变量名了

    x = tf.placeholder(tf.float32, shape=(None,))
    x = 3   ### no no no 前面已经用过了

    feed

    即喂数据,通常需要用 占位符 来填坑;

    喂给的数据不能是 tensor,只能是 python 的数据类型;

    The value of a feed cannot be a tf.Tensor object. Acceptable feed values include Python scalars, strings, lists, numpy ndarrays, or TensorHandles

    也可以是 fetch 的结果,因为 op 的 fetch 为 numpy;

    示例

    d1 = tf.placeholder(dtype=tf.float32, shape=[2, 2])
    d2 = tf.placeholder(dtype=tf.float32, shape=[2, 1])
    d3 = tf.matmul(d1, d2)
    
    sess2 = tf.Session()
    a = np.array([[1.,2.], [2., 3.]])
    b = np.array([[4.], [4.]])
    print(sess2.run(d3, feed_dict={d1:a, d2:b}))
    # [[12.]
    #  [20.]]

    fetch

    即获取,获取节点的输出,

    可以获取单个节点的输出,也可以同时执行多个节点,获取多个节点的输出

    d1 = tf.placeholder(dtype=tf.float32, shape=[2, 2])
    d2 = tf.placeholder(dtype=tf.float32, shape=[2, 1])
    d3 = tf.matmul(d1, d2)
    
    sess2 = tf.Session()
    a = np.array([[1.,2.], [2., 3.]])
    b = np.array([[4.], [4.]])
    print(sess2.run(d3, feed_dict={d1:a, d2:b}))
    # [[12.]
    #  [20.]]
    
    print(type(sess2.run(d3, feed_dict={d1:a, d2:b})))      # <class 'numpy.ndarray'>

    可以看到输出的 Tensor 为 ndarray 类型

    循环 fetch 时,fetch 的变量名 不要和 op 的名字一样,否则,变量名 覆盖了 op 名,下次你的 op 就不是 tensor 而是 ndarray 了

    op1, op2 = tf....
    sess = tf.Session()
    for i in range(10):
        op1, op2 = sess.run([op1, op2])     # no

    如果乱写,报错 must be a string or Tensor. (Can not convert a float32 into a Tensor or Operation.)

    参考资料:

  • 相关阅读:
    Shell 学习笔记之函数
    Shell 学习笔记之条件语句
    Shell 学习笔记之运算符
    Shell 学习笔记之变量
    [LeetCode] Zuma Game 题解
    [LeetCode] Decode String 题解
    [LeetCode] Pacific Atlantic Water Flow 题解
    TCP的建立和终止 图解
    [LeetCode] 01 Matrix 题解
    java中protect属性用法总结
  • 原文地址:https://www.cnblogs.com/yanshw/p/12345722.html
Copyright © 2011-2022 走看看