计算图和tf.function
官网
Introduction to graphs and tf.function | TensorFlow Core
1. 概述
主要涉及Tensorflow和Keras的处理逻辑,即如何通过对代码进行简单修改来获取计算图,如何存储和表示计算图,以及如何使用它们来加速模型。
计算图是什么?
计算图是一种数据结构,包含了一系列tf.Operation
,这些tf.Operation被表示为计算节点;还包含了一系列tf.Tensor
,Tensor被表示为数据节点在计算节点间传输。
Graphs are data structures that contain a set of tf.Operation objects, which represent units of computation; and tf.Tensor objects, which represent the units of data that flow between operations.
计算图的优点
相比Eager Execution,计算图可以更好的移植到Python以外的语言并且更高效。可以让训练更快、进行并行训练、在不同设备上进行分布式训练
2. 利用计算图的优势
可以使用tf.function
来创建和使用计算图,tf.function可以将一个普通的Python函数封装成Tensorflow Function
。Function
可以像使用普通函数一样被调用,但其底层是封装了多个tf.Graph
API来实现的。
import tensorflow as tf
import timeit
from datetime import datetime
# Define a Python function.
def a_regular_function(x, y, b):
x = tf.matmul(x, y)
x = x + b
return x
# `a_function_that_uses_a_graph` is a TensorFlow `Function`.
a_function_that_uses_a_graph = tf.function(a_regular_function)
# Make some tensors.
x1 = tf.constant([[1.0, 2.0]])
y1 = tf.constant([[2.0], [3.0]])
b1 = tf.constant(4.0)
orig_value = a_regular_function(x1, y1, b1).numpy()
# Call a `Function` like a Python function.
tf_function_value = a_function_that_uses_a_graph(x1, y1, b1).numpy()
assert(orig_value == tf_function_value)
也可以使用@tf.function
注解来标注计算图,并且可以让其内部调用的函数也使用计算图
def inner_function(x, y, b):
x = tf.matmul(x, y)
x = x + b
return x
# Use the decorator to make `outer_function` a `Function`.
@tf.function
def outer_function(x):
y = tf.constant([[2.0], [3.0]])
b = tf.constant(4.0)
return inner_function(x, y, b)
# Note that the callable will create a graph that
# includes `inner_function` as well as `outer_function`.
outer_function(tf.constant([[1.0, 2.0]])).numpy()
将Python函数转换为计算图
在实现的函数中,会混合使用tensorflow的操作和Python的操作(如if、break、return等),tensorflow的操作可以简单的被计算图捕捉,但Python的逻辑需要用一个单独的库AutoGraph(tf.autograph
)来实现转换。
def simple_relu(x):
if tf.greater(x, 0):
return x
else:
return 0
tf_simple_relu = tf.function(simple_relu)
# This is the graph-generating output of AutoGraph.
# 计算图方法
print(tf.autograph.to_code(simple_relu))
# This is the graph itself.
# 计算图的结构
print(tf_simple_relu.get_concrete_function(tf.constant(1)).graph.as_graph_def())
多态(Polymorphism):一个函数,多个计算图
一个tf.Graph
只能处理特定的参数,在调用Function
时如果使用新的dtype
和shape
的数据,Function
都会创建新的tf.Graph
来处理。因此dtype
和shape
可以被看作是计算图的签名(input signature)。Function将计算图和其对应的签名包装在ConcreteFunction
中,重复调用会复用生成好的计算图。
一个Function
对应了多个Graph
,所以Function可以使用不同的参数进行调用,因此Function是多态的。
3.如何使用tf.function
默认情况下Function
会使用计算图,如果想要使用Eager Execution就需要配置tf.config.run_functions_eagerly(True),并在使用完成后将配置修改回来。
Eager和Graph的对比
在使用Graph时,结果只被打印了一次。
因为Tensorflow会先执行一次方法,并使用一种被称为“tracing”的技术来记录操作从而生成Graph,但print方法并没有被记录在Graph中。之后的调用都使用Graph来执行,所以整个过程只打印了一次。如果想在Graph中每次也进行打印,需要使用tf.print
。
@tf.function
def get_MSE(y_true, y_pred):
print("Calculating MSE!")
sq_diff = tf.pow(y_true - y_pred, 2)
return tf.reduce_mean(sq_diff)
error = get_MSE(y_true, y_pred)
error = get_MSE(y_true, y_pred)
error = get_MSE(y_true, y_pred)
====>
Calculating MSE!
而Eager使用原生的Python调用,所以会打印三次
# 设置是全局的,使用完后一定要设置回来
tf.config.run_functions_eagerly(True)
error = get_MSE(y_true, y_pred)
error = get_MSE(y_true, y_pred)
error = get_MSE(y_true, y_pred)
tf.config.run_functions_eagerly(False)
===>
Calculating MSE!
Calculating MSE!
Calculating MSE!
tf.function的最佳实践
@tf.function来使用计算图,有以下一些技巧:
- 经常使用run_functions_eagerly切换Eager和Graph方式,来确认这两种方式在什么时候执行与预期不符合。
- 在Python方法外创建
tf.Variable
,在方法内修改它们。这个规则适用于其他使用了tf.Variable
的对象,如tf.layers
、tf.Model
、tf.optimizer
- 避免方法依赖外部的Python变量。(
tf.Variable
和tf.Karas
相关对象例外) - 尽量使用Tensorflow对象作为入参,如果要使用其他对象务必要小心!
- 将大多数计算都放在方法中,以便计算图可以对其进行优化
4.关于性能提升
由于调用和生成计算图也有开销,所以对于少量计算并不太能感知到性能的提升。对于训练的多个批次,性能确实会有较大提升,可以使用以下方式进行比较
x = tf.random.uniform(shape=[10, 10], minval=-1, maxval=2, dtype=tf.dtypes.int32)
def power(x, y):
result = tf.eye(10, dtype=tf.dtypes.int32)
for _ in range(y):
result = tf.matmul(x, result)
return result
print("Eager execution:", timeit.timeit(lambda: power(x, 100), number=1000))
power_as_graph = tf.function(power)
print("Graph execution:", timeit.timeit(lambda: power_as_graph(x, 100), number=1000))
===>
Eager execution: 1.777665522999996
Graph execution: 0.5308018169999968
权衡
对于某些方法,生成计算图的开销可能比执行还大,这个开销一般会在方法被重复调用的性能提升抵消掉。
任何大型的模型前几次训练批次都会因为Tracing而变得比较慢,tf.function中讨论了如何使用特定入参和tensor参数来避免重复Tracing。
5. Function Tracing是什么?
只要在方法中加入print,每次tracing的时候就会被打印,从而可以知道什么时候进行了tracing。
@tf.function
def a_function_with_python_side_effect(x):
print("Tracing!") # An eager-only side effect.
return x * x + tf.constant(2)
# 如果使用tensor标量可以看到方法只会被tracing一次
# This is traced the first time.
print(a_function_with_python_side_effect(tf.constant(2)))
# The second time through, you won't see the side effect.
print(a_function_with_python_side_effect(tf.constant(3)))
==>
Tracing!
tf.Tensor(6, shape=(), dtype=int32)
tf.Tensor(11, shape=(), dtype=int32)
如果换成用python的原始类型,即使多次调用的参数都是整型,也会进行多次tracing。这种场景经常出现在训练的epoch或超参数!!
# This retraces each time the Python argument changes,
# as a Python argument could be an epoch count or other
# hyperparameter.
print(a_function_with_python_side_effect(2))
print(a_function_with_python_side_effect(3))
==>
Tracing!
tf.Tensor(6, shape=(), dtype=int32)
Tracing!
tf.Tensor(11, shape=(), dtype=int32)