zoukankan      html  css  js  c++  java
  • tensorflow2学习笔记---Graph和tf.function

    计算图和tf.function

    官网

    Introduction to graphs and tf.function | TensorFlow Core


    1. 概述

    主要涉及Tensorflow和Keras的处理逻辑,即如何通过对代码进行简单修改来获取计算图,如何存储和表示计算图,以及如何使用它们来加速模型。

    计算图是什么?

    计算图是一种数据结构,包含了一系列tf.Operation,这些tf.Operation被表示为计算节点;还包含了一系列tf.Tensor,Tensor被表示为数据节点在计算节点间传输。

    Graphs are data structures that contain a set of tf.Operation objects, which represent units of computation; and tf.Tensor objects, which represent the units of data that flow between operations.

    计算图的优点

    相比Eager Execution,计算图可以更好的移植到Python以外的语言并且更高效。可以让训练更快、进行并行训练、在不同设备上进行分布式训练


    2. 利用计算图的优势

    可以使用tf.function来创建和使用计算图,tf.function可以将一个普通的Python函数封装成Tensorflow FunctionFunction可以像使用普通函数一样被调用,但其底层是封装了多个tf.Graph API来实现的。

    import tensorflow as tf
    import timeit
    from datetime import datetime
    
    # Define a Python function.
    def a_regular_function(x, y, b):
      x = tf.matmul(x, y)
      x = x + b
      return x
    
    # `a_function_that_uses_a_graph` is a TensorFlow `Function`.
    a_function_that_uses_a_graph = tf.function(a_regular_function)
    
    # Make some tensors.
    x1 = tf.constant([[1.0, 2.0]])
    y1 = tf.constant([[2.0], [3.0]])
    b1 = tf.constant(4.0)
    
    orig_value = a_regular_function(x1, y1, b1).numpy()
    # Call a `Function` like a Python function.
    tf_function_value = a_function_that_uses_a_graph(x1, y1, b1).numpy()
    assert(orig_value == tf_function_value)
    

    也可以使用@tf.function注解来标注计算图,并且可以让其内部调用的函数也使用计算图

    def inner_function(x, y, b):
      x = tf.matmul(x, y)
      x = x + b
      return x
    
    # Use the decorator to make `outer_function` a `Function`.
    @tf.function
    def outer_function(x):
      y = tf.constant([[2.0], [3.0]])
      b = tf.constant(4.0)
    
      return inner_function(x, y, b)
    
    # Note that the callable will create a graph that
    # includes `inner_function` as well as `outer_function`.
    outer_function(tf.constant([[1.0, 2.0]])).numpy()
    

    将Python函数转换为计算图

    在实现的函数中,会混合使用tensorflow的操作和Python的操作(如if、break、return等),tensorflow的操作可以简单的被计算图捕捉,但Python的逻辑需要用一个单独的库AutoGraph(tf.autograph)来实现转换。

    def simple_relu(x):
      if tf.greater(x, 0):
        return x
      else:
        return 0
    
    tf_simple_relu = tf.function(simple_relu)
    
    # This is the graph-generating output of AutoGraph.
    # 计算图方法
    print(tf.autograph.to_code(simple_relu))
    
    # This is the graph itself.
    # 计算图的结构
    print(tf_simple_relu.get_concrete_function(tf.constant(1)).graph.as_graph_def())
    

    多态(Polymorphism):一个函数,多个计算图

    一个tf.Graph只能处理特定的参数,在调用Function时如果使用新的dtypeshape的数据,Function都会创建新的tf.Graph来处理。因此dtypeshape可以被看作是计算图的签名(input signature)。Function将计算图和其对应的签名包装在ConcreteFunction中,重复调用会复用生成好的计算图。

    一个Function对应了多个Graph,所以Function可以使用不同的参数进行调用,因此Function是多态的。


    3.如何使用tf.function

    默认情况下Function会使用计算图,如果想要使用Eager Execution就需要配置tf.config.run_functions_eagerly(True),并在使用完成后将配置修改回来。


    Eager和Graph的对比

    在使用Graph时,结果只被打印了一次。

    因为Tensorflow会先执行一次方法,并使用一种被称为“tracing”的技术来记录操作从而生成Graph,但print方法并没有被记录在Graph中。之后的调用都使用Graph来执行,所以整个过程只打印了一次。如果想在Graph中每次也进行打印,需要使用tf.print

    @tf.function
    def get_MSE(y_true, y_pred):
    	print("Calculating MSE!")
      sq_diff = tf.pow(y_true - y_pred, 2)
      return tf.reduce_mean(sq_diff)
    
    error = get_MSE(y_true, y_pred)
    error = get_MSE(y_true, y_pred)
    error = get_MSE(y_true, y_pred)
    
    ====>
    Calculating MSE!
    

    而Eager使用原生的Python调用,所以会打印三次

    # 设置是全局的,使用完后一定要设置回来
    tf.config.run_functions_eagerly(True)
    error = get_MSE(y_true, y_pred)
    error = get_MSE(y_true, y_pred)
    error = get_MSE(y_true, y_pred)
    tf.config.run_functions_eagerly(False)
    
    ===>
    Calculating MSE!
    Calculating MSE!
    Calculating MSE!
    

    tf.function的最佳实践

    @tf.function来使用计算图,有以下一些技巧:

    • 经常使用run_functions_eagerly切换Eager和Graph方式,来确认这两种方式在什么时候执行与预期不符合。
    • 在Python方法外创建tf.Variable,在方法内修改它们。这个规则适用于其他使用了tf.Variable的对象,如tf.layerstf.Modeltf.optimizer
    • 避免方法依赖外部的Python变量。(tf.Variabletf.Karas相关对象例外)
    • 尽量使用Tensorflow对象作为入参,如果要使用其他对象务必要小心!
    • 将大多数计算都放在方法中,以便计算图可以对其进行优化

    4.关于性能提升

    由于调用和生成计算图也有开销,所以对于少量计算并不太能感知到性能的提升。对于训练的多个批次,性能确实会有较大提升,可以使用以下方式进行比较

    x = tf.random.uniform(shape=[10, 10], minval=-1, maxval=2, dtype=tf.dtypes.int32)
    
    def power(x, y):
      result = tf.eye(10, dtype=tf.dtypes.int32)
      for _ in range(y):
        result = tf.matmul(x, result)
      return result
    
    print("Eager execution:", timeit.timeit(lambda: power(x, 100), number=1000))
    
    power_as_graph = tf.function(power)
    print("Graph execution:", timeit.timeit(lambda: power_as_graph(x, 100), number=1000))
    
    ===>
    Eager execution: 1.777665522999996
    Graph execution: 0.5308018169999968
    

    权衡

    对于某些方法,生成计算图的开销可能比执行还大,这个开销一般会在方法被重复调用的性能提升抵消掉。

    任何大型的模型前几次训练批次都会因为Tracing而变得比较慢,tf.function中讨论了如何使用特定入参和tensor参数来避免重复Tracing。


    5. Function Tracing是什么?

    只要在方法中加入print,每次tracing的时候就会被打印,从而可以知道什么时候进行了tracing。

    @tf.function
    def a_function_with_python_side_effect(x):
      print("Tracing!") # An eager-only side effect.
      return x * x + tf.constant(2)
    
    # 如果使用tensor标量可以看到方法只会被tracing一次
    # This is traced the first time.
    print(a_function_with_python_side_effect(tf.constant(2)))
    # The second time through, you won't see the side effect.
    print(a_function_with_python_side_effect(tf.constant(3)))
    
    ==>
    Tracing!
    tf.Tensor(6, shape=(), dtype=int32)
    tf.Tensor(11, shape=(), dtype=int32)
    

    如果换成用python的原始类型,即使多次调用的参数都是整型,也会进行多次tracing。这种场景经常出现在训练的epoch或超参数!!

    # This retraces each time the Python argument changes,
    # as a Python argument could be an epoch count or other
    # hyperparameter.
    print(a_function_with_python_side_effect(2))
    print(a_function_with_python_side_effect(3))
    
    ==>
    Tracing!
    tf.Tensor(6, shape=(), dtype=int32)
    Tracing!
    tf.Tensor(11, shape=(), dtype=int32)
    
  • 相关阅读:
    JavaScript 原型和原型链 prototype
    javascript dom 表单元素之 radio
    JavaScript Dom 表单元素之 checkbox
    JavaScript DOM 表单元素之 select
    JavaScript-ECMAScript 之模块
    Javascript--ECMAScript 之 this
    Javascript-ECMAscript--Array.prototype.slice() 方法
    JavaScript -ECMAScriopt: Array.prototype.slice.call()详解及转换数组的方法
    JavaScript-ECMASCript apply call bind
    requests的深入刨析及封装调用
  • 原文地址:https://www.cnblogs.com/hikari-1994/p/14664157.html
Copyright © 2011-2022 走看看