zoukankan      html  css  js  c++  java
  • Tensorflow计算模型 —— 计算图

    转载自:http://blog.csdn.net/john_xyz/article/details/69053626

      Tensorflow是一个通过计算图的形式来表述计算的编程系统,计算图也叫数据流图,可以把计算图看做是一种有向图,Tensorflow中的每一个计算都是计算图上的一个节点,而节点之间的边描述了计算之间的依赖关系。

    计算图的使用

      在tensorflow程序中,系统会维护一个默认的计算图,通过tf.get_default_graph()函数可以获取当前默认的计算图,为了向默认的计算图中添加一个操作,我们只需要简单的调用一个函数:

    c = tf.constant(3.0) 
    assert c.graph == tf.get_default_graph()

      除了使用默认的计算图,Tensorflow支持通过tf.Graph()函数来生成新的计算图,不同计算图的张量和运算都不会共享,使用tf.Graph.as_default()覆盖当前的默认图。

    g = tf.Graph()
    with g.as_default():
        c = tf.constant(3.0)
        assert c.graph is g
    #coding:utf-8
    import tensorflow as tf
    
    g1 = tf.Graph()
    with g1.as_default():
        # 在图g1中定义初始变量c, 并设置初始值为0
        v = tf.get_variable("v", shape=[1], initializer = tf.zeros_initializer(dtype=tf.float32))
    
    g2 = tf.Graph()
    with g2.as_default():
        # 在图g1中定义初始变量c, 并设置初始值为1
        v = tf.get_variable("v", shape=[1], initializer = tf.ones_initializer(dtype=tf.float32))
    
    with tf.Session(graph=g1) as sess:
        sess.run(tf.global_variables_initializer())
        with tf.variable_scope('', reuse=True):
            # 输出值为0
            print sess.run(tf.get_variable("v"))
    
    with tf.Session(graph=g2) as sess:
        sess.run(tf.global_variables_initializer())
        with tf.variable_scope('', reuse=True):
           # 输出值为1
           print sess.run(tf.get_variable('v'))

      上面的代码产生了两个计算图,当运行不同的计算图时,变量v的值是不一样的。同时,计算图Graph通过tf.Graph.device()函数来制定运行计算图的设备, 下图定义的程序可以将加法计算跑在GPU上

    g = tf.Graph()
    # 指定计算运行的设备
    with g.device('/gpu:0'):
        result = a + b

      在一个计算图中,可以通过集合(collection)来管理不同类别的资源,一个计算图Graph实例支持任意数量的 name定义的collection, 当构建一个计算图时,collections可以存储一组相关的对象。

      例如:tf.Variables使用一个collection (named tf.GraphKeys.GLOBAL_VARIABLES)存储所有的变量,当构建计算图的时候。可以通过tf.add_to_collection()函数将资源加入一个collection中,然后通过tf.get_collection获取一个集合里面的所有资源。

           tensorflow中自动管理了一些常用的集合,如下表:

    集合名称

    集合内容 

    使用场景

    tf.GraphKeys.VARIABLES

    所有变量

    持久化tensorflow模型

    tf.GraphKeys.TRAINABLE_VARIABLES

    可学习的变量(一般指神经网络中的参数)

    模型训练、生成模型可视化内容

    tf.GraphKeys.SUMMARIES

    日志生成相关的张量

    tensorflow计算可视化

    tf.GraphKeys.QUEUE_RUNNERS

    处理输入的QueueRunner

    输入处理

    tf.GraphKeys.MOVING_AVERAGE_VARIABLES

    所有计算了滑动平均值的变量

    计算变量的滑动平均值

  • 相关阅读:
    SPSS分析技术:CMH检验(分层卡方检验);辛普森悖论,数据分析的谬误
    揭秘10个大数据神话 为你排除几个误区
    SPSS统计分析案例:无空白列重复正交试验设计方差分析
    SPSS统计分析案例:无空白列重复正交试验设计方差分析
    单点登录系统和CAS的简介
    多线程
    HTTP 400错误--请求无效
    前端框架bootstrap(响应式布局)入门
    MQ(队列消息的入门)
    ActiveMQ下载与安装(Linux环境下进行)
  • 原文地址:https://www.cnblogs.com/hejunlin1992/p/8270569.html
Copyright © 2011-2022 走看看