zoukankan      html  css  js  c++  java
  • TensorFlow学习之graph和session

    引言:

             按照我的理解,graph相当于一块空白的面包板(电路板,可以在上面插电路元件设计电路,供电调试),上面搭建了许多子模块,自顶向下构成一整个功能电路,整个模块有输入输出端口,里面的子模块或者部分子模块

    构成的模块也有输入和输出端口。搭建好的这个电路板就是 'graph',而测试(执行)相应模块都需要开一个session,上面有很多模块可以调试,也就是说一个graph可以开多个session来执行各自模块的输入输出过程。借网上dao来的图作个说明:

       

     

     接下来看一小段代码:

    参考链接:https://blog.csdn.net/vinceee__/article/details/88075451   https://blog.csdn.net/vinceee__/article/details/88075451

    注:以下伪代码仅帮助理解

     1 val = tf.variable(initilizer, ...)  #变量定义
     2 const_1 = tf.constant()          #常量定义
     3 data =3 #定义
     4 op = tf.assign(val, const_1) #定义赋值操作
     5 c = tf.placeholder(dtype= ,shape = [])  #placeholder占位符定义
     6 #使用之前定义的变量, 需要定义一个初始化操作init, sess开始后执行变量初始化
     7 init = tf.global_varibales_initializer()#全局变量初始化操作
     8  with tf.session() as sess:
     9        sess.run(init)   #执行初始化操作
    10        sess.run(op)    #执行操作: 将常量const_1赋值给变量val
    11        print(sess.run(val))
    12        print(sess.run(c, feed_dict = {c:data})#将data值传递到c这个占位符,执行打印 c

    以上代码没有新建graph,默认使用的是TensorFlow的全局的defaultgraph.如果需要使用自己定义的graph,则使用with tf.graph().as_default()来取代默认的全局计算图  

    一、GRAPH

     1、有关graph

    计算图graph 由许多个node 组成,node又由name(名称),op(操作),input(输入),attrs(属性:dtype,shape, size, value... ...构成),对于tf.constant()函数只会产生一个node,但是对于tf.varibale(initialize, name),其中生成一个initializer初始化器,一共会产生三个node:

    1)variable:变量维护(不存放实际的值)

    2) Varibale/assign: 变量分配

    3) variable/read: 变量读取使用

    使用变量时需要进行变量初始化:生成初始化器对象,再在sess中执行init.

    2、graph的使用

     1) Tf.get_default_graph():    获得当前默认graph

     2) With  Tf.Graph().as_default():  tf.graph()创建一个新的graph,并通过as_default()设置为上下文中的默认图(局部的),即with下面所有的定义都是在这个新建的 (局部的)graph中。

        With之外另外的graph作为默认,get_default_graph可以获得当前的图。

     3) tensorflow 有一个默认的全局graph, 不用定义可直接使用,若需要另外的graph则需要新建(tf.graph()在当前使用with范围内设置为default)

    二、Session

     1、有关session

           tf.session 是用来运行TensorFlow操作的类,一个session对象封装了操作执行对象的环境,在这个环境下才可以对tensor对象进行计算,tensor对象不可直接进行计算操作,同时也会负责分配计算资源和变量存放:

    1)步骤: build a graph       #构建一个计算图(里面有操作、tensor对象)

                   launch the graph in a session  #基于以上创建的graph创建一个session

                   evaluate the tensor         #执行计算

    2) session: 包括一些资源,例如variables, queues, readers 当我们不再需要这些资源时候,我们需要释放这些资源,释放资源可以通过执行sess.close(),或者将session作为上下文管理器使用:

    with tf.session() as sess:

           .... ....

           sess.run()

    执行到with范围以外,session资源释放

    3) configProto protocol buffer 公开session的各种配置选项。

            ConfigProto详解:

            

                tf.configProto函数用在创建session的时候对session进行参数配置。

                3.1)Tf.ConfigProto(log­_device_placement=True)

                        log­_device_placement=True,可以获取到operations 和tensor 被指派到哪个设备(几号CPU/GPU)上运行,会在终端打印出各项操作在哪个设备上面运行。

                3.2) Tf.ConfigProto(allow_soft_placement=True)

                     在tf中,通过“with tf.device('cpu:0'):”来手动设置操作运行的设备。如果手动设置的设备不存在或者不可用,就会导致程序等待或者异常,为了防止这种情况,allow_soft_placement=True,可以允许tf自动选择一个存在并可用的设备

                     来运行操作。

               3.3) 限制GPU的资源使用:

                        为了加快运行效率,tf在初始化时候会尝试分配所有可用GPU显存资源给自己,如果在多人使用的服务器上工作会导致别人无法正常使用GPU。

                       tf提供了两种控制GPU资源使用的方法:

                      3.3.1) 设置动态申请显存,需要多少就申请多少

                               

    sess_config = tf.configProto(log_device_placement=False, allow_soft_placement = True)
    sess_config.gpu_options.allow_growth = True
    session = tf.session(config = sess_config)

                      3.3.2) 限制GPU的使用率,如下两种方法:

                        

    sess_config = tf.configProto(log_device_placement= Fa;se ,allow_soft_placementt= True)
    sess_config.gpu_options.per_process_gpu_memory_fraction = 0.4#占用40%显存
    session = tf.session(config = sess_config)
    
    
    gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction= 0.4)
    sess_config = tf.configProto(gpu_options = gpu_options)

                    3.3.3) 设置使用哪块GPU:

                                os.environ['CUDA_VISIBLE_DEVICES']= '0'

                                os.environ['CUDA_VISIBLE_DEVICES']='0,1'

                                使用上下文管理器,手动指定session在哪块gpu上执行

                                with tf.device('/gpu:0'):

                                       ... ... ... ...

                                      with tf.session(config = tf.configProto(log_device_placement=True)) as sess:

                                            print sess.run()

                        参考网址:    https://blog.csdn.net/dcrmg/article/details/79091941

                                             https://www.cnblogs.com/ywheunji/p/11390219.html  

    4) tf.session.run():       

                        Tf.Session.run(  fetches,  feed_dict=None):

                         Fetches:    是个list,里面包括了我们想要输出的一个或者多个graph元素,tensor/ sparse tensor、operation等,见以上文献例子。

                         Feed_dict:   是个dict,里面包括需要输入的参数名称和实际参数传入的key_value 对。

                        这个method运行一步TensorFlow计算,通过运行需要的graph 片段来执行每个操作和计算fetches里的每个Tensor,用feed_dict里面的值替换相应的输入值

    5) tf.session() as sess:     和      tf.session().as_default() as sess:     

    g = tf.graph()                          #新建graph
    session = tf.session(graph = g)  #将新建的graph加载到session
    with g.as_default()      #必须为当前指定default graph(因为可能有多个graph存在)

           5.1) with session.as_default() as sess:

    g_model =  tf.Graph()
    g_session = tf.Session(graph = g_model)
    
    with g_model.as_default() as g:
     with g_session.as_default() as sess:
          c = tf.constant(1)
          print(sess.run(c))
          print(tf.get_default_session())
          print(tf.get_default_graph())
    print(tf.get_default_graph())
    print(sess.run(c))
    sess.close()
    print(sess.run(c))

         print results:

     with之外 sess.close 之前 sess.run(c): 1  也就是说sess没有自动关闭; 手动sess.close之后,sess.run(c)就error了,而get_default_session只能在with内有效,with外是None.

         5.2)  tf.session() as sess:

    with g_Model.as_default() as g:
      with g_session as sess:
           c=tf.constant(1)
          print(sess.run(c))
          print(tf.get_default_session())
          print(tf.get_default_graph())
    print(tf.get_default_graph())
    print(sess.run(c))

     print results:

    可见:没有session.as_default(), with 以外sess自动关闭,sess.run(c)就报错

    以上是关于session 和graph的总结,感谢参考文档的作者,这里作个学习笔记,小白一枚,欢迎大家指正,叩谢~~

        

     

     

  • 相关阅读:
    [LeetCode] 143. 重排链表
    [LeetCode] 342. 4的幂
    [LeetCode] 1744. 你能在你最喜欢的那天吃到你最喜欢的糖果吗?
    [LeetCode] 148. 排序链表
    [LeetCode] 525. 连续数组
    [LeetCode] 160. 相交链表
    [LeetCode] 134. 加油站
    [LeetCode] 474. 一和零
    CentOS 升级 OpenSSH
    AWS 证书取消挂靠
  • 原文地址:https://www.cnblogs.com/zzc-Andy/p/12108385.html
Copyright © 2011-2022 走看看