zoukankan      html  css  js  c++  java
  • Tensorflow简单实践系列(三):图和会话

    当执行一个 TensorFlow 函数的时候,并不会马上执行运算,而是把运算存储到一个称为“图”(graph)的数据结构里面。

    图存储的各种运算,只有在会话(session)里执行图,才会真正地执行。

    图的构建

    对于

    1 c = tf.add(a, b)
    2 e = tf.multiply(c, d) 

    它们所形成的图就是:

    TensorFlow 用 Graph 这个容器数据结构来表示图。图的方法可以分为两类:

    1. 访问图中的数据
    2. 创建 GraphDef

    访问图中的数据

    有这么一些访问图数据的方法:

    • get_tensor_by_name(name):根据 name 返回张量。
    • get_operation_by_name(name):根据 name 返回运算。
    • get_operations():返回运算的列表。
    • get_all_collection_keys():返回集合的列表。
    • get_collection(name, scope=None):返回给定集合的值列表。
    • add_to_collection(name, value):添加值
    • add_to_collections(name, value):添加值

     示例代码:

    1 # 访问图中的数据
    2 x1 = tf.constant(2, name='x1')
    3 x2 = tf.constant(3, name='x2')
    4 my_sum = x1 + x2
    5 print(tf.get_default_graph().get_operations())
    6 print(tf.get_default_graph().get_tensor_by_name('x1:0'))
    [<tf.Operation 'x1' type=Const>, <tf.Operation 'x2' type=Const>, <tf.Operation 'add' type=Add>]
    Tensor("x1:0", shape=(), dtype=int32)

    其中 'x1:0' 表示的是 'name:index',0 表示的是这个张量的索引。

    创建 GraphDef

    GraphDef 是序列化之后的 Graph。

    GraphDef 以一种特殊的格式(protocol buffer 或 protobuf)存储图中的数据。protobuf 可以是二进制格式或者文本格式(长得像 JSON)。

    在 GraphDef 中,所有的张量和运算都用节点来表示。每个节点都有 name/op/attr 这些字段。它的样子就像:

    node {
        name: { ... }
        op: { ... }
        attr { ... }
        attr { ... }
        ...
        versions { ... }
    }

    再通过一段代码来熟悉,as_graph_def 可以访问 TensorFlow 应用中的图:

    1 a = tf.constant(666)
    2 b = tf.constant(777)
    3 sum1 = a + b
    4 print(tf.get_default_graph().as_graph_def())
    node {
      name: "Const"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_INT32
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_INT32
            tensor_shape {
            }
            int_val: 666
          }
        }
      }
    }
    node {
      name: "Const_1"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_INT32
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_INT32
            tensor_shape {
            }
            int_val: 777
          }
        }
      }
    }
    node {
      name: "add"
      op: "Add"
      input: "Const"
      input: "Const_1"
      attr {
        key: "T"
        value {
          type: DT_INT32
        }
      }
    }
    versions {
      producer: 38
    }

    tf.train 中的 write_graph 可以把图输出到文件。

    函数签名如下:

    write_graph(graph/graph_def, logdir, name, as_text=True)

    代码示例:

    print(tf.train.write_graph(tf.get_default_graph(), os.getcwd(), 'graph.dat', as_text=True))

    此时会输出:

    /your/path/graph.dat

    即新生成了这个文件。

    创建并运行会话

    在 TensorFlow 里,都是先构建好 Graph,然后再在会话(session)中执行。

    会话的创建

    会话必须显式地创建,通过 tf.Session,它有 3 个参数:

    • target:执行引擎(execution engine)的名称
    • graph:启动的图实例
    • config:配置

    一般我们使用默认参数,那就是:

    1 with tf.Session() as sess:
    2     pass

    会话的执行

    session 最重要的方法就是 run(),它接收 4 个方法:

    • fetches: 指定若干个需要执行的张量或运算
    • feed_dict: 需要喂给张量的数据
    • options: 配置参数
    • run_metadata: 会话的输出数据

    如果 fetches 是一个张量,run 会返回一个和张量等值的 ndarray。

    1 t = tf.constant([6, 66, 666])
    2 with tf.Session() as sess:
    3     res = sess.run(t)
    4     print(res)
    [  6  66 666]

    如果 fetches 是一个运算,run 会返回一个运算之后的 ndarray 值。

    1 t1 = tf.constant(6)
    2 t2 = tf.constant(66)
    3 my_multiply = t1 * t2
    4 
    5 with tf.Session() as sess:
    6     res = sess.run(my_multiply)
    7     print(res)
    396

    如果 fetches 是元素的集合,run 也会返回一个相应的集合。

    1 t1 = tf.constant(6)
    2 t2 = tf.constant(66)
    3 
    4 with tf.Session() as sess:
    5     res1, res2 = sess.run([t1, t2])
    6     print(res1)
    7     print(res2)
    6
    66

    输出到日志

    TensorFlow 的日志是通过 tf.logging 实现的。示例代码:

    1 import tensorflow.compat.v1 as tf
    2 
    3 tf.logging.set_verbosity(tf.logging.INFO)
    4 t = tf.constant(6)
    5 
    6 with tf.Session() as sess:
    7     res = sess.run(t)
    8     tf.logging.info('Output: %f', res)
    I0713 18:06:02.146098 140734845322688 <ipython-input-22-3ef84fc83efc>:8] Output: 6.000000
  • 相关阅读:
    python flask学习笔记
    语音识别2 -- Listen,Attend,and Spell (LAS)
    语音识别 1--概述
    keras中seq2seq实现
    ResNet模型
    Bytes类型
    Python操作文件
    Pyhon基本数据类型
    ping
    find
  • 原文地址:https://www.cnblogs.com/noluye/p/11150657.html
Copyright © 2011-2022 走看看