zoukankan      html  css  js  c++  java
  • TensorFlow学习之graph和session

    引言:

             按照我的理解,graph相当于一块空白的面包板(电路板,可以在上面插电路元件设计电路,供电调试),上面搭建了许多子模块,自顶向下构成一整个功能电路,整个模块有输入输出端口,里面的子模块或者部分子模块

    构成的模块也有输入和输出端口。搭建好的这个电路板就是 'graph',而测试(执行)相应模块都需要开一个session,上面有很多模块可以调试,也就是说一个graph可以开多个session来执行各自模块的输入输出过程。借网上dao来的图作个说明:

       

     

     接下来看一小段代码:

    参考链接:https://blog.csdn.net/vinceee__/article/details/88075451   https://blog.csdn.net/vinceee__/article/details/88075451

    注:以下伪代码仅帮助理解

     1 val = tf.variable(initilizer, ...)  #变量定义
     2 const_1 = tf.constant()          #常量定义
     3 data =3 #定义
     4 op = tf.assign(val, const_1) #定义赋值操作
     5 c = tf.placeholder(dtype= ,shape = [])  #placeholder占位符定义
     6 #使用之前定义的变量, 需要定义一个初始化操作init, sess开始后执行变量初始化
     7 init = tf.global_varibales_initializer()#全局变量初始化操作
     8  with tf.session() as sess:
     9        sess.run(init)   #执行初始化操作
    10        sess.run(op)    #执行操作: 将常量const_1赋值给变量val
    11        print(sess.run(val))
    12        print(sess.run(c, feed_dict = {c:data})#将data值传递到c这个占位符,执行打印 c

    以上代码没有新建graph,默认使用的是TensorFlow的全局的defaultgraph.如果需要使用自己定义的graph,则使用with tf.graph().as_default()来取代默认的全局计算图  

    一、GRAPH

     1、有关graph

    计算图graph 由许多个node 组成,node又由name(名称),op(操作),input(输入),attrs(属性:dtype,shape, size, value... ...构成),对于tf.constant()函数只会产生一个node,但是对于tf.varibale(initialize, name),其中生成一个initializer初始化器,一共会产生三个node:

    1)variable:变量维护(不存放实际的值)

    2) Varibale/assign: 变量分配

    3) variable/read: 变量读取使用

    使用变量时需要进行变量初始化:生成初始化器对象,再在sess中执行init.

    2、graph的使用

     1) Tf.get_default_graph():    获得当前默认graph

     2) With  Tf.Graph().as_default():  tf.graph()创建一个新的graph,并通过as_default()设置为上下文中的默认图(局部的),即with下面所有的定义都是在这个新建的 (局部的)graph中。

        With之外另外的graph作为默认,get_default_graph可以获得当前的图。

     3) tensorflow 有一个默认的全局graph, 不用定义可直接使用,若需要另外的graph则需要新建(tf.graph()在当前使用with范围内设置为default)

    二、Session

     1、有关session

           tf.session 是用来运行TensorFlow操作的类,一个session对象封装了操作执行对象的环境,在这个环境下才可以对tensor对象进行计算,tensor对象不可直接进行计算操作,同时也会负责分配计算资源和变量存放:

    1)步骤: build a graph       #构建一个计算图(里面有操作、tensor对象)

                   launch the graph in a session  #基于以上创建的graph创建一个session

                   evaluate the tensor         #执行计算

    2) session: 包括一些资源,例如variables, queues, readers 当我们不再需要这些资源时候,我们需要释放这些资源,释放资源可以通过执行sess.close(),或者将session作为上下文管理器使用:

    with tf.session() as sess:

           .... ....

           sess.run()

    执行到with范围以外,session资源释放

    3) configProto protocol buffer 公开session的各种配置选项。

            ConfigProto详解:

            

                tf.configProto函数用在创建session的时候对session进行参数配置。

                3.1)Tf.ConfigProto(log­_device_placement=True)

                        log­_device_placement=True,可以获取到operations 和tensor 被指派到哪个设备(几号CPU/GPU)上运行,会在终端打印出各项操作在哪个设备上面运行。

                3.2) Tf.ConfigProto(allow_soft_placement=True)

                     在tf中,通过“with tf.device('cpu:0'):”来手动设置操作运行的设备。如果手动设置的设备不存在或者不可用,就会导致程序等待或者异常,为了防止这种情况,allow_soft_placement=True,可以允许tf自动选择一个存在并可用的设备

                     来运行操作。

               3.3) 限制GPU的资源使用:

                        为了加快运行效率,tf在初始化时候会尝试分配所有可用GPU显存资源给自己,如果在多人使用的服务器上工作会导致别人无法正常使用GPU。

                       tf提供了两种控制GPU资源使用的方法:

                      3.3.1) 设置动态申请显存,需要多少就申请多少

                               

    sess_config = tf.configProto(log_device_placement=False, allow_soft_placement = True)
    sess_config.gpu_options.allow_growth = True
    session = tf.session(config = sess_config)

                      3.3.2) 限制GPU的使用率,如下两种方法:

                        

    sess_config = tf.configProto(log_device_placement= Fa;se ,allow_soft_placementt= True)
    sess_config.gpu_options.per_process_gpu_memory_fraction = 0.4#占用40%显存
    session = tf.session(config = sess_config)
    
    
    gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction= 0.4)
    sess_config = tf.configProto(gpu_options = gpu_options)

                    3.3.3) 设置使用哪块GPU:

                                os.environ['CUDA_VISIBLE_DEVICES']= '0'

                                os.environ['CUDA_VISIBLE_DEVICES']='0,1'

                                使用上下文管理器,手动指定session在哪块gpu上执行

                                with tf.device('/gpu:0'):

                                       ... ... ... ...

                                      with tf.session(config = tf.configProto(log_device_placement=True)) as sess:

                                            print sess.run()

                        参考网址:    https://blog.csdn.net/dcrmg/article/details/79091941

                                             https://www.cnblogs.com/ywheunji/p/11390219.html  

    4) tf.session.run():       

                        Tf.Session.run(  fetches,  feed_dict=None):

                         Fetches:    是个list,里面包括了我们想要输出的一个或者多个graph元素,tensor/ sparse tensor、operation等,见以上文献例子。

                         Feed_dict:   是个dict,里面包括需要输入的参数名称和实际参数传入的key_value 对。

                        这个method运行一步TensorFlow计算,通过运行需要的graph 片段来执行每个操作和计算fetches里的每个Tensor,用feed_dict里面的值替换相应的输入值

    5) tf.session() as sess:     和      tf.session().as_default() as sess:     

    g = tf.graph()                          #新建graph
    session = tf.session(graph = g)  #将新建的graph加载到session
    with g.as_default()      #必须为当前指定default graph(因为可能有多个graph存在)

           5.1) with session.as_default() as sess:

    g_model =  tf.Graph()
    g_session = tf.Session(graph = g_model)
    
    with g_model.as_default() as g:
     with g_session.as_default() as sess:
          c = tf.constant(1)
          print(sess.run(c))
          print(tf.get_default_session())
          print(tf.get_default_graph())
    print(tf.get_default_graph())
    print(sess.run(c))
    sess.close()
    print(sess.run(c))

         print results:

     with之外 sess.close 之前 sess.run(c): 1  也就是说sess没有自动关闭; 手动sess.close之后,sess.run(c)就error了,而get_default_session只能在with内有效,with外是None.

         5.2)  tf.session() as sess:

    with g_Model.as_default() as g:
      with g_session as sess:
           c=tf.constant(1)
          print(sess.run(c))
          print(tf.get_default_session())
          print(tf.get_default_graph())
    print(tf.get_default_graph())
    print(sess.run(c))

     print results:

    可见:没有session.as_default(), with 以外sess自动关闭,sess.run(c)就报错

    以上是关于session 和graph的总结,感谢参考文档的作者,这里作个学习笔记,小白一枚,欢迎大家指正,叩谢~~

        

     

     

  • 相关阅读:
    解决UITableView中Cell重用机制导致内容出错的方法总结
    Hdu 1052 Tian Ji -- The Horse Racing
    Hdu 1009 FatMouse' Trade
    hdu 2037 今年暑假不AC
    hdu 1559 最大子矩阵
    hdu 1004 Let the Balloon Rise
    Hdu 1214 圆桌会议
    Hdu 1081 To The Max
    Hdu 2845 Beans
    Hdu 2955 Robberies 0/1背包
  • 原文地址:https://www.cnblogs.com/zzc-Andy/p/12108385.html
Copyright © 2011-2022 走看看