zoukankan      html  css  js  c++  java
  • tensorflow计算图

    tensorflow计算图

          计算图是对有向图的表示,主要包含点和边;tensorflow使用计算图计算,计算图的点对应于ops,variables,constant,placeholder等,边对应于Tensors。因此tensorflow主要包含两个部分:构建计算图和runtime运行计算图。

     为什么要用计算图?

    1. 并行化,因为计算图是对计算的一种抽象,点之间的关系取决其依赖关系。因此,互相不依赖的计算可以并行计算,在多集群环境下可以进行分布式计算。
    2. 可移植性性,因为图是一种语言无关的表示方式,tensorflow 中使用protobuf来存储图,可以使用C++,python,jave等语言来解析图。 

    总结一下,tensorflow 中要进行计算主要进行两个步骤: 1. 构建graph; 2. session evaluate tensor

    假如要实现一个类似tensorflow框架,需要如何操作?

    • node 节点的实现:
    class Node():
        def __init__(self, input_nodes):
            self.input_nodes=input_nodes
            self.output=Node
        
        def forward(self):
            pass
        def backward(self):
            pass
    class add(Node):
        def forward(self, a,b):
            return a+b
        def backward(self,upstream_grad):
            NotImpl
    
    class multiple(Node):
        def forward(self, a,b):
            return a*b
        def backward(self,upstream_grad):
            NotImpl
    • graph 实现
    class Graph():
        def __init__(self):
            self.nodes = []
        def as_default(self):
            global _default_graph
    
        def add_node(self,node):
            _default_graph.append(node)
    • session 实现
    # session 接受要run的operation,要得到输出要拓扑排序所有的node,在session run的时候,按正确顺序执行。 可以表示为DAG(directed acyclic graph) 
    #  将要执行的node作为unvisited node加入栈中,利用深度优先搜索的方式递归遍历所有的node,当node没有其他输入时,将node标记为visited出栈。出栈的顺序就是拓扑排序。 
    
    class Session():
        def run(self, node, feed_dict={}):
            nodes_sorted=topology_sort(node)
            for node in nodes_sorted:
                if type(node)==Placeholder:
                    node.output=feed_dict[node]
                elif type(node)==Variable or type(node)==Constant:
                    node.output=node.value
                else:
                    inputs=[node.output for node in node.input_nodes]
                    node.output=node.forward(*inputs)
            return node.output          

    完整的code:

    import numpy as np
    
    class Graph():
      def __init__(self):
        self.operations = []
        self.placeholders = []
        self.variables = []
        self.constants = []
    
      def as_default(self):
        global _default_graph
        _default_graph = self
    
    class Operation():
      def __init__(self, input_nodes=None):
        self.input_nodes = input_nodes
        self.output = None
        
        # Append operation to the list of operations of the default graph
        _default_graph.operations.append(self)
    
      def forward(self):
        pass
    
      def backward(self):
        pass
    
    class BinaryOperation(Operation):
      def __init__(self, a, b):
        super().__init__([a, b])
    
    class add(BinaryOperation):
      """
      Computes a + b element-wise
      """
      def forward(self, a, b):
        return a + b
    
      def backward(self, upstream_grad):
        raise NotImplementedError
    
    class multiply(BinaryOperation):
      """
      Computes a * b, element-wise
      """
      def forward(self, a, b):
        return a * b
    
      def backward(self, upstream_grad):
        raise NotImplementedError
    
    class divide(BinaryOperation):
      """
      Returns the true division of the inputs, element-wise
      """
      def forward(self, a, b):
        return np.true_divide(a, b)
    
      def backward(self, upstream_grad):
        raise NotImplementedError
    
    class matmul(BinaryOperation):
      """
      Multiplies matrix a by matrix b, producing a * b
      """
      def forward(self, a, b):
        return a.dot(b)
    
      def backward(self, upstream_grad):
        raise NotImplementedError
    
    class Placeholder():
      def __init__(self):
        self.value = None
        _default_graph.placeholders.append(self)
    
    class Constant():
      def __init__(self, value=None):
        self.__value = value
        _default_graph.constants.append(self)
    
      @property
      def value(self):
        return self.__value
    
      @value.setter
      def value(self, value):
        raise ValueError("Cannot reassign value.")
    
    class Variable():
      def __init__(self, initial_value=None):
        self.value = initial_value
        _default_graph.variables.append(self)
    
    def topology_sort(operation):
        ordering = []
        visited_nodes = set()
    
        def recursive_helper(node):
          if isinstance(node, Operation):
            for input_node in node.input_nodes:
              if input_node not in visited_nodes:
                recursive_helper(input_node)
    
          visited_nodes.add(node)
          ordering.append(node)
    
        # start recursive depth-first search
        recursive_helper(operation)
    
        return ordering
    
    class Session():
      def run(self, operation, feed_dict={}):
        nodes_sorted = topology_sort(operation)
    
        for node in nodes_sorted:
          if type(node) == Placeholder:
            node.output = feed_dict[node]
          elif type(node) == Variable or type(node) == Constant:
            node.output = node.value
          else:
            inputs = [node.output for node in node.input_nodes]
            node.output = node.forward(*inputs)
    
        return operation.output
    
    
    
    
    
    
    import tf_api as tf
    
    # create default graph
    tf.Graph().as_default()
    
    # construct computational graph by creating some nodes
    a = tf.Constant(15)
    b = tf.Constant(5)
    prod = tf.multiply(a, b)
    sum = tf.add(a, b)
    res = tf.divide(prod, sum)
    
    # create a session object
    session = tf.Session()
    
    # run computational graph to compute the output for 'res'
    out = session.run(res)
    print(out)
    
    
  • 相关阅读:
    layui的模块化和非模块化使用
    layui实现类似于bootstrap的模态框功能
    ajax下载文件
    【IDEA】IDEA中maven项目pom.xml依赖不生效解决
    主-主数据库系统架构
    MyEclipse x.x各版本终极优化配置指南
    Cactus入门
    有史以来最出彩的编程语言名字
    安卓开发20:动画之Animation 详细使用-主要通过java代码实现动画效果
    第一次讲课总结
  • 原文地址:https://www.cnblogs.com/fanhaha/p/12341575.html
Copyright © 2011-2022 走看看