zoukankan      html  css  js  c++  java
  • tensorflow计算图

    tensorflow计算图

          计算图是对有向图的表示,主要包含点和边;tensorflow使用计算图计算,计算图的点对应于ops,variables,constant,placeholder等,边对应于Tensors。因此tensorflow主要包含两个部分:构建计算图和runtime运行计算图。

     为什么要用计算图?

    1. 并行化,因为计算图是对计算的一种抽象,点之间的关系取决其依赖关系。因此,互相不依赖的计算可以并行计算,在多集群环境下可以进行分布式计算。
    2. 可移植性性,因为图是一种语言无关的表示方式,tensorflow 中使用protobuf来存储图,可以使用C++,python,jave等语言来解析图。 

    总结一下,tensorflow 中要进行计算主要进行两个步骤: 1. 构建graph; 2. session evaluate tensor

    假如要实现一个类似tensorflow框架,需要如何操作?

    • node 节点的实现:
    class Node():
        def __init__(self, input_nodes):
            self.input_nodes=input_nodes
            self.output=Node
        
        def forward(self):
            pass
        def backward(self):
            pass
    class add(Node):
        def forward(self, a,b):
            return a+b
        def backward(self,upstream_grad):
            NotImpl
    
    class multiple(Node):
        def forward(self, a,b):
            return a*b
        def backward(self,upstream_grad):
            NotImpl
    • graph 实现
    class Graph():
        def __init__(self):
            self.nodes = []
        def as_default(self):
            global _default_graph
    
        def add_node(self,node):
            _default_graph.append(node)
    • session 实现
    # session 接受要run的operation,要得到输出要拓扑排序所有的node,在session run的时候,按正确顺序执行。 可以表示为DAG(directed acyclic graph) 
    #  将要执行的node作为unvisited node加入栈中,利用深度优先搜索的方式递归遍历所有的node,当node没有其他输入时,将node标记为visited出栈。出栈的顺序就是拓扑排序。 
    
    class Session():
        def run(self, node, feed_dict={}):
            nodes_sorted=topology_sort(node)
            for node in nodes_sorted:
                if type(node)==Placeholder:
                    node.output=feed_dict[node]
                elif type(node)==Variable or type(node)==Constant:
                    node.output=node.value
                else:
                    inputs=[node.output for node in node.input_nodes]
                    node.output=node.forward(*inputs)
            return node.output          

    完整的code:

    import numpy as np
    
    class Graph():
      def __init__(self):
        self.operations = []
        self.placeholders = []
        self.variables = []
        self.constants = []
    
      def as_default(self):
        global _default_graph
        _default_graph = self
    
    class Operation():
      def __init__(self, input_nodes=None):
        self.input_nodes = input_nodes
        self.output = None
        
        # Append operation to the list of operations of the default graph
        _default_graph.operations.append(self)
    
      def forward(self):
        pass
    
      def backward(self):
        pass
    
    class BinaryOperation(Operation):
      def __init__(self, a, b):
        super().__init__([a, b])
    
    class add(BinaryOperation):
      """
      Computes a + b element-wise
      """
      def forward(self, a, b):
        return a + b
    
      def backward(self, upstream_grad):
        raise NotImplementedError
    
    class multiply(BinaryOperation):
      """
      Computes a * b, element-wise
      """
      def forward(self, a, b):
        return a * b
    
      def backward(self, upstream_grad):
        raise NotImplementedError
    
    class divide(BinaryOperation):
      """
      Returns the true division of the inputs, element-wise
      """
      def forward(self, a, b):
        return np.true_divide(a, b)
    
      def backward(self, upstream_grad):
        raise NotImplementedError
    
    class matmul(BinaryOperation):
      """
      Multiplies matrix a by matrix b, producing a * b
      """
      def forward(self, a, b):
        return a.dot(b)
    
      def backward(self, upstream_grad):
        raise NotImplementedError
    
    class Placeholder():
      def __init__(self):
        self.value = None
        _default_graph.placeholders.append(self)
    
    class Constant():
      def __init__(self, value=None):
        self.__value = value
        _default_graph.constants.append(self)
    
      @property
      def value(self):
        return self.__value
    
      @value.setter
      def value(self, value):
        raise ValueError("Cannot reassign value.")
    
    class Variable():
      def __init__(self, initial_value=None):
        self.value = initial_value
        _default_graph.variables.append(self)
    
    def topology_sort(operation):
        ordering = []
        visited_nodes = set()
    
        def recursive_helper(node):
          if isinstance(node, Operation):
            for input_node in node.input_nodes:
              if input_node not in visited_nodes:
                recursive_helper(input_node)
    
          visited_nodes.add(node)
          ordering.append(node)
    
        # start recursive depth-first search
        recursive_helper(operation)
    
        return ordering
    
    class Session():
      def run(self, operation, feed_dict={}):
        nodes_sorted = topology_sort(operation)
    
        for node in nodes_sorted:
          if type(node) == Placeholder:
            node.output = feed_dict[node]
          elif type(node) == Variable or type(node) == Constant:
            node.output = node.value
          else:
            inputs = [node.output for node in node.input_nodes]
            node.output = node.forward(*inputs)
    
        return operation.output
    
    
    
    
    
    
    import tf_api as tf
    
    # create default graph
    tf.Graph().as_default()
    
    # construct computational graph by creating some nodes
    a = tf.Constant(15)
    b = tf.Constant(5)
    prod = tf.multiply(a, b)
    sum = tf.add(a, b)
    res = tf.divide(prod, sum)
    
    # create a session object
    session = tf.Session()
    
    # run computational graph to compute the output for 'res'
    out = session.run(res)
    print(out)
    
    
  • 相关阅读:
    hdu acm 2844 Coins 解题报告
    hdu 1963 Investment 解题报告
    codeforces 454B. Little Pony and Sort by Shift 解题报告
    广大暑假训练1 E题 Paid Roads(poj 3411) 解题报告
    hdu acm 2191 悼念512汶川大地震遇难同胞——珍惜现在,感恩生活
    hdu acm 1114 Piggy-Bank 解题报告
    poj 2531 Network Saboteur 解题报告
    数据库范式
    ngnix 配置CI框架 与 CI的简单使用
    Vundle的安装
  • 原文地址:https://www.cnblogs.com/fanhaha/p/12341575.html
Copyright © 2011-2022 走看看