本书到目前为止一直都在使用命令式编程,它使用编程语句改变程序状态。考虑下面这段简单的命令式程序。
def add(a, b):
return a + b
def fancy_func(a, b, c, d):
e = add(a, b)
f = add(c, d)
g = add(e, f)
return g
fancy_func(1, 2, 3, 4) # 10
和我们预期的一样,在运行语句e = add(a, b)
时,Python会做加法运算并将结果存储在变量e
中,从而令程序的状态发生改变。类似地,后面的两条语句f = add(c, d)
和g = add(e, f)
会依次做加法运算并存储变量。
虽然使用命令式编程很方便,但它的运行可能很慢。一方面,即使fancy_func
函数中的add
是被重复调用的函数,Python也会逐一执行这3条函数调用语句。另一方面,我们需要保存变量e
和f
的值直到fancy_func
中所有语句执行结束。这是因为在执行e = add(a, b)
和f = add(c, d)
这2条语句之后我们并不知道变量e
和f
是否会被程序的其他部分使用。
与命令式编程不同,符号式编程通常在计算流程完全定义好后才被执行。多个深度学习框架,如Theano和TensorFlow,都使用了符号式编程。通常,符号式编程的程序需要下面3个步骤:
- 定义计算流程;
- 把计算流程编译成可执行的程序;
- 给定输入,调用编译好的程序执行。
下面我们用符号式编程重新实现本节开头给出的命令式编程代码。
def add_str():
return '''
def add(a, b):
return a + b
'''
def fancy_func_str():
return '''
def fancy_func(a, b, c, d):
e = add(a, b)
f = add(c, d)
g = add(e, f)
return g
'''
def evoke_str():
return add_str() + fancy_func_str() + '''
print(fancy_func(1, 2, 3, 4))
'''
prog = evoke_str()
print(prog)
y = compile(prog, '', 'exec')
exec(y)
输出:
def add(a, b):
return a + b
def fancy_func(a, b, c, d):
e = add(a, b)
f = add(c, d)
g = add(e, f)
return g
print(fancy_func(1, 2, 3, 4))
10
以上定义的3个函数都仅以字符串的形式返回计算流程。最后,我们通过compile
函数编译完整的计算流程并运行。由于在编译时系统能够完整地获取整个程序,因此有更多空间优化计算。例如,编译的时候可以将程序改写成print((1 + 2) + (3 + 4))
,甚至直接改写成print(10)
。这样不仅减少了函数调用,还节省了内存。
对比这两种编程方式,我们可以看到以下两点。
-
命令式编程更方便。当我们在Python里使用命令式编程时,大部分代码编写起来都很直观。同时,命令式编程更容易调试。这是因为我们可以很方便地获取并打印所有的中间变量值,或者使用Python的调试工具。
-
符号式编程更高效并更容易移植。一方面,在编译的时候系统容易做更多优化;另一方面,符号式编程可以将程序变成一个与Python无关的格式,从而可以使程序在非Python环境下运行,以避开Python解释器的性能问题。
大部分深度学习框架在命令式编程和符号式编程之间二选一。例如,Theano和受其启发的后来者TensorFlow1.x
使用了符号式编程,Chainer
和它的追随者PyTorch
使用了命令式编程。开发人员在设计Tensorflow2.x
时思考了这个问题:有没有可能既得到命令式编程的好处,又享受符号式编程的优势?开发者们认为,用户应该用纯命令式编程进行开发和调试;当需要产品级别的计算性能和部署时,用户可以将大部分命令式程序转换成符号式程序来运行。Tensorflow通过提供静态图转换器tf.function
,实现对两种编程方式的支持。在 Tensorflow
在不使用静态图转换器tf.function
时,用户编写的python
函数默认会采用命令式编程逐行执行,符合python
编程的直觉,便于调试,但因为框架不能获得完整的静态运算图,不能进行优化,且动态图不能序列化。官方由此推出了静态图转换器tf.function
,其作用在python_function
后会将这个函数"编译"成一个运算图,接受input_tensors
为输入并用图执行的方式计算结果,可以加速函数执行,并且可以被序列化后供任何其他语言(如C++和Java)调用,这样用户只需要在使用动态图运算编写和测试完毕函数之后使用tf.function装饰一下就能够获得静态图的所有优点。
tf.function
可以定义一个Tensorflow
操作,既可以命令式的执行它,
@tf.function
def add(a, b):
return a+b
add(tf.ones([2, 2]), tf.ones([2, 2])) # [[2., 2.], [2., 2.]]
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[2., 2.],
[2., 2.]], dtype=float32)>
也可以在图中执行它,并求其梯度,
v = tf.Variable(1.0)
with tf.GradientTape() as tape:
result = add(v, 1.0)
tape.gradient(result, v)
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>
也可以定义嵌套定义(当然,在实际使用中,可以直接在顶层定义,会自动对子图进行转换),
@tf.function
def dense_layer(x, w, b):
return add(tf.matmul(x, w), b)
dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))
<tf.Tensor: shape=(3, 2), dtype=float32, numpy=
array([[3., 3.],
[3., 3.],
[3., 3.]], dtype=float32)>
Python
的动态类型意味着用户可以传递多种类型的参数,这也许会导致函数产生不同的行为。
在另一方面,Tensorflow
的静态图需要确定的数据类型和维度。tf.function
通过retracing
函数调用来弥补这一差距,并在必要的时候产生正确的计算图。tf.function
大多数微妙的行为都产生自retracing
行为。
我们可以用不同的参数调用同一函数来观察retracing
行为:
# Functions are polymorphic
@tf.function
def double(a):
print("Tracing with", a)
return a + a
print(double(tf.constant(1)))
print()
print(double(tf.constant(1.1)))
print()
print(double(tf.constant("a")))
print()
Tracing with Tensor("a:0", shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)
Tracing with Tensor("a:0", shape=(), dtype=float32)
tf.Tensor(2.2, shape=(), dtype=float32)
Tracing with Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'aa', shape=(), dtype=string)
如果希望控制 tracing
行为,可以用如下方式操作:
- 创建新的
tf.function
。分离tf.fucntion
对象,保证没有共享的计算图引用。 - 使用
get_concrete_function
方法,得到特定的计算图。 - 声明
input_signature
当调用tf.function
时,仅跟踪与输入签名一致的调用。
print("Obtaining concrete trace")
double_strings = double.get_concrete_function(tf.TensorSpec(shape=None, dtype=tf.string))
print("Executing traced function")
print(double_strings(tf.constant("a")))
print(double_strings(a=tf.constant("b")))
print("Using a concrete trace with incompatible types will throw an error")
with assert_raises(tf.errors.InvalidArgumentError):
double_strings(tf.constant(1))
Obtaining concrete trace Tracing with Tensor("a:0", dtype=string) Executing traced function tf.Tensor(b'aa', shape=(), dtype=string) tf.Tensor(b'bb', shape=(), dtype=string) Using a concrete trace with incompatible types will throw an error Caught expected exception <class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>: Traceback (most recent call last): File "", line 8, in assert_raises yield File "", line 8, in double_strings(tf.constant(1)) tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute __inference_double_87 as input #0(zero-based) was expected to be a string tensor but is a int32 tensor [Op:__inference_double_87]
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
print("Tracing with", x)
return tf.where(x % 2 == 0, x // 2, 3 * x + 1)
print(next_collatz(tf.constant([1, 2])))
# We specified a 1-D tensor in the input signature, so this should fail.
with assert_raises(ValueError):
next_collatz(tf.constant([[1, 2], [3, 4]]))
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([4 1], shape=(2,), dtype=int32)
Caught expected exception
<class 'ValueError'>:
Traceback (most recent call last):
File "<ipython-input-3-73d0ca52e838>", line 8, in assert_raises
yield
File "<ipython-input-9-9939c82c1507>", line 9, in <module>
next_collatz(tf.constant([[1, 2], [3, 4]]))
ValueError: Python inputs incompatible with input_signature:
inputs: (
tf.Tensor(
[[1 2]
[3 4]], shape=(2, 2), dtype=int32))
input_signature: (
TensorSpec(shape=(None,), dtype=tf.int32, name=None))
多态函数 tf.function
会缓存之前追踪行为触发生成过的具体函数。缓存的键由传入的参数确定,对于 tf.Tensor
参数而言,是其维度和类型,而对于 Python
元语,是其值。对于其它 Python类型,使用对象 id,即对每个不同的类实例都会触发独立的追踪行为,并生成相应的静态图。
通常,Python
参数被用作超参数,如 num_layers=10
、training=True
以及nonlinearity='relu'
。此时,Python
参数的改变触发 retrace
行为来构建新的计算图是合理的。然而,在另一些情况下,Python
参数并不改变计算图,是不需要触发 retrace
重新构建计算图的。例如,在训练过程中控制步数,AutoGraph 会自动动态展开,因此传入不同的步数,其生成图是一致的,这时如果触发多个 trace
生成同样的计算图,是很低效的。
def train_one_step():
pass
@tf.function
def train(num_steps):
print("Tracing with num_steps = {}".format(num_steps))
for _ in tf.range(num_steps):
train_one_step()
train(num_steps=10)
train(num_steps=20)
Tracing with num_steps = 10
Tracing with num_steps = 20
一种简单的绕过方式是,将参数转换为 Tensor,这样不改变 shape,就不会触发生成计算图。
train(num_steps=tf.constant(10))
train(num_steps=tf.constant(20))
Tracing with num_steps = Tensor("num_steps:0", shape=(), dtype=int32)
通常,Python
的附带效应(print
或改变对象)仅发生在 tracing
行为中。何时应该触发 tf.function
的附带效应呢?
经验上推荐仅使用附带效应来 debug trace
行为。其它情况则建议使用 Tensorflow
运算如 tf.Variable.assign
、tf.print
和tf.summary
来在 Tensorflow runtime
保证代码跟踪和执行。帮助调试的最佳实践是使用函数式风格。
@tf.function
def f(x):
print("Traced with", x)
tf.print("Executed with", x)
f(1)
f(1)
f(2)
Traced with 1
Executed with 1
Executed with 1
Traced with 2
Executed with 2
如果想要在每次调用 tf.function
时执行 Python
代码,tf.py_function
提供了这种方式的支持。使用 tf.py_function
的一个缺点是性能,不能很好的工作在分布式环境下(多GPU/TPU)。由于 tf.py_function
需要被可微的连入计算图,输入和输出都会被转换为 tf.Tensor
。
external_list = []
def side_effect(x):
print('Python side effect')
external_list.append(x)
@tf.function
def f(x):
tf.py_function(side_effect, inp=[x], Tout=[])
f(1)
f(1)
f(1)
assert len(external_list) == 3
# .numpy() call required because py_function casts 1 to tf.constant(1)
assert external_list[0].numpy() == 1
Python side effect
Python side effect
Python side effect
许多 Python 特性,如生成器和迭代器,依赖于 Python 运行时跟踪其状态。通常,这些构件在动态图模式下工作正常,但由于 tracing
行为,在 tf.function
内会发生预期外行为。
例如,迭代器状态被视为一种 Python 附带效应,因此只在 tracing
时触发一次。
external_var = tf.Variable(0)
@tf.function
def buggy_consume_next(iterator):
external_var.assign_add(next(iterator))
tf.print("Value of external_var:", external_var)
iterator = iter([0, 1, 2, 3])
buggy_consume_next(iterator)
# This reuses the first value from the iterator, rather than consuming the next value.
buggy_consume_next(iterator)
buggy_consume_next(iterator)
Value of external_var: 0
Value of external_var: 0
Value of external_var: 0
如果一个迭代器生成和消费完全在 tf.function
中,它会正确地工作,但会产生巨大地计算图,这也许不符合你的预期,更严重的是,对于以 Python List 表示的内存中大型数据集,相应的大型计算图也并不能带来性能提升。
如果想要在 Python
数据集上迭代,最安全的方式是使用 tf.data.Dataset
封装,并以 for x in y
方式遍历。AutoGraph 对 for
循环中 y
是一个 tf.Tensor
或 tf.data.Dataset
有特殊支持。
def measure_graph_size(f, *args):
g = f.get_concrete_function(*args).graph
print("{}({}) contains {} nodes in its graph".format(
f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))
@tf.function
def train(dataset):
loss = tf.constant(0)
for x, y in dataset:
loss += tf.abs(y - x) # Some dummy computation.
return loss
small_data = [(1, 1)] * 2
big_data = [(1, 1)] * 10
measure_graph_size(train, small_data)
measure_graph_size(train, big_data)
measure_graph_size(train, tf.data.Dataset.from_generator(
lambda: small_data, (tf.int32, tf.int32)))
measure_graph_size(train, tf.data.Dataset.from_generator(
lambda: big_data, (tf.int32, tf.int32)))
train([(1, 1), (1, 1)]) contains 8 nodes in its graph
train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]) contains 32 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 9 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 9 nodes in its graph
当封装 Python/Numpy
数据为 tf.data.Dataset
时,要注意 tf.data.Dataset.from_generator
与 tf.data.Dataset.from_tensors
的区别。前者保持数据在 Python runtime
中,通过 tf.py_function
取数,会造成一定的性能瓶颈。而后者会复制一份到 Tensorflow runtime
成为计算图的一个 tf.constant()
节点,会占用更多的内存。
从文件中读数据,如 TFRecordDataset/CsvDataset
等,是读取数据最高效的方式,Tensorflow
可以自行管理异步读取和预加载数据,而不需要 Python
的参与。
作为编程模型,tf.function
一个非常吸引人的特性是,在通常的数据流图之上,为运行时环境提供了更多关于代码预期行为的信息。
例如,当写代码是多次读写同一变量,数据流图也许不会编码原本预期的操作顺序。在 tf.function
中,则会依照在 Python
代码中的声明顺序消除执行顺序的歧义。这使得 tf.function
支持状态操作,可以复制动态图模式的语义。
这意味着不再需要手动添加控制依赖,而可以交由 tf.function
自动添加控制依赖。
# Automatic control dependencies
a = tf.Variable(1.0)
b = tf.Variable(2.0)
@tf.function
def f(x, y):
a.assign(y * b)
b.assign_add(x * a)
return a + b
f(1.0, 2.0) # 10.0
<tf.Tensor: shape=(), dtype=float32, numpy=10.0>
变量也可能使动态图模式的执行结果和静态图模式产生差异。当每次调用创建一个新变量时,由于 tracing
语义,tf.function
会在每次调用时重用同一变量,但动态图时会在每次调用时新建相应的变量。为了避免类似的问题,tf.function
将在监测到危险的变量创建行为时抛出异常。
@tf.function
def f(x):
v = tf.Variable(1.0)
v.assign_add(x)
return v
with assert_raises(ValueError):
f(1.0)
<ipython-input-17-73e410646579>:3 f *
v = tf.Variable(1.0)
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/ops/variables.py:260 __call__
return cls._variable_v2_call(*args, **kwargs)
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/ops/variables.py:254 _variable_v2_call
shape=shape)
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/ops/variables.py:65 getter
return captured_getter(captured_previous, **kwargs)
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/eager/def_function.py:502 invalid_creator_scope
"tf.function-decorated function tried to create "
ValueError: tf.function-decorated function tried to create variables on non-first call.
而无歧义的代码就不会触发相应的行为
v = tf.Variable(1.0)
@tf.function
def f(x):
return v.assign_add(x)
print(f(1.0)) # 2.0
print(f(2.0)) # 4.0
tf.Tensor(2.0, shape=(), dtype=float32)
tf.Tensor(4.0, shape=(), dtype=float32)
只要可以证明tf.function
中的变量只在函数初次执行时被创建,也是可以通过检查的。
class C:
pass
obj = C()
obj.v = None
@tf.function
def g(x):
if obj.v is None:
obj.v = tf.Variable(1.0)
return obj.v.assign_add(x)
print(g(1.0)) # 2.0
print(g(2.0)) # 4.0
tf.Tensor(2.0, shape=(), dtype=float32)
tf.Tensor(4.0, shape=(), dtype=float32)
AutoGraph
集成在 tf.function
中,用来重写依赖与张量的条件分支和循环,以支持计算图动态改变结构。
tf.cond
和 tf.while_loop
依旧与 tf.function
兼容,但以命令式写控制流更简单直观。
# Simple loop
@tf.function
def f(x):
while tf.reduce_sum(x) > 1:
tf.print(x)
x = tf.tanh(x)
return x
f(tf.random.uniform([5]))
[0.654489756 0.378447413 0.843570352 0.706569076 0.899703264]
[0.57468468 0.361358374 0.687695503 0.608520865 0.716153383]
[0.518791437 0.346409947 0.596499443 0.543085039 0.614520967]
[0.476766676 0.333187848 0.534554 0.495319664 0.54730171]
[0.443650424 0.321382284 0.488854468 0.458428353 0.498495191]
[0.416665703 0.310756236 0.453306764 0.428802431 0.460932881]
[0.394117773 0.301124901 0.424613386 0.40432 0.430844247]
[0.374904692 0.292341709 0.400809854 0.383639246 0.406026632]
[0.358274311 0.284288675 0.380641699 0.36586377 0.385093749]
[0.343693078 0.276869684 0.36326462 0.3503685 0.367122918]
[0.330770403 0.270005435 0.348086357 0.336702287 0.351472586]
[0.319212824 0.263629884 0.334677339 0.324530154 0.337680846]
[0.308794975 0.257687539 0.322717279 0.313597322 0.325405359]
[0.299340427 0.252131343 0.31196183 0.303706169 0.314386249]
[0.290708899 0.246921107 0.302220792 0.294700593 0.30442214]
[0.282787144 0.242022276 0.293343604 0.286455452 0.295354247]
[0.275482684 0.237404957 0.285209358 0.278869152 0.287055373]
[0.268719077 0.233043134 0.277719557 0.271858126 0.279422313]
[0.262432516 0.228914022 0.27079317 0.265352964 0.272370338]
[0.256569326 0.22499761 0.264362723 0.259295493 0.265829057]
[0.25108391 0.221276194 0.258371592 0.25363645 0.259739518]
[0.245937288 0.217734098 0.252771795 0.248333916 0.254051864]
[0.241095871 0.214357331 0.247522414 0.243351877 0.248723671]
[0.236530572 0.211133406 0.242588282 0.238659233 0.24371852]
[0.23221606 0.2080511 0.237939 0.234228939 0.239004955]
[0.228130147 0.205100328 0.233548105 0.230037391 0.234555662]
[0.224253282 0.202271983 0.229392484 0.226063833 0.230346799]
[0.220568195 0.199557811 0.225451797 0.222289979 0.226357415]
[0.217059553 0.196950331 0.221708104 0.218699604 0.222569034]
[0.213713691 0.194442704 0.21814549 0.215278283 0.218965292]
[0.210518375 0.192028716 0.214749783 0.212013125 0.215531647]
[0.207462624 0.189702675 0.211508334 0.208892584 0.21225509]
[0.204536542 0.18745935 0.208409771 0.205906287 0.209123984]
[0.201731205 0.185293958 0.205443889 0.203044847 0.206127867]
<tf.Tensor: shape=(5,), dtype=float32, numpy=
array([0.19903852, 0.18320207, 0.20260146, 0.20029978, 0.20325728],
dtype=float32)>
如下代码可以观察 autograph
生成的代码
def f(x):
while tf.reduce_sum(x) > 1:
tf.print(x)
x = tf.tanh(x)
return x
print(tf.autograph.to_code(f))
def tf__f(x):
do_return = False
retval_ = ag__.UndefinedReturnValue()
with ag__.FunctionScope('f', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
def get_state():
return ()
def set_state(_):
pass
def loop_body(x):
ag__.converted_call(tf.print, (x,), None, fscope)
x = ag__.converted_call(tf.tanh, (x,), None, fscope)
return x,
def loop_test(x):
return ag__.converted_call(tf.reduce_sum, (x,), None, fscope) > 1
x, = ag__.while_stmt(loop_test, loop_body, get_state, set_state, (x,), ('x',), ())
do_return = True
retval_ = fscope.mark_return_value(x)
do_return,
return ag__.retval(retval_)
AutoGraph
将 if
语句转换为等效的 tf.cond
调用。这一替换发生在条件变量为张量时,除此之外的条件,在 tracing
时确定。
test_tf_cond
函数用来检查函数中是否使用了 tf.cond
def test_tf_cond(f, *args):
g = f.get_concrete_function(*args).graph
if any(node.name == 'cond' for node in g.as_graph_def().node):
print("{}({}) uses tf.cond.".format(
f.__name__, ', '.join(map(str, args))))
else:
print("{}({}) executes normally.".format(
f.__name__, ', '.join(map(str, args))))
print(" result: ",f(*args).numpy())
当参数为 python True
时,正常地执行条件:
@tf.function
def dropout(x, training=True):
if training:
x = tf.nn.dropout(x, rate=0.5)
return x
test_tf_cond(dropout, tf.ones([10], dtype=tf.float32), True)
dropout(tf.Tensor([1. 1. 1. 1. 1. 1. 1. 1. 1. 1.], shape=(10,), dtype=float32), True) executes normally.
result: [2. 2. 2. 2. 0. 0. 2. 0. 0. 2.]
但传递一个张量则会使 python if
替换为 tf.cond
:
test_tf_cond(dropout, tf.ones([10], dtype=tf.float32), tf.constant(True))
dropout(tf.Tensor([1. 1. 1. 1. 1. 1. 1. 1. 1. 1.], shape=(10,), dtype=float32), tf.Tensor(True, shape=(), dtype=bool)) uses tf.cond.
result: [0. 2. 2. 2. 0. 0. 0. 0. 2. 2.]
AutoGraph
有一些转换循环的简单规则。
for
:迭代器是张量时转换while
:循环条件与张量有关时转换
如果一个循环被转换,它将由 tf.while
动态展开,或在 for x in tf.data.Dataset
情况下,将循环转换为 tf.data.Dataset.reduce
。
如果循环没有被转换,则静态展开。
test_dynamically_unrolled(f, *args)
def test_dynamically_unrolled(f, *args):
g = f.get_concrete_function(*args).graph
if any(node.name == 'while' for node in g.as_graph_def().node):
print("{}({}) uses tf.while_loop.".format(
f.__name__, ', '.join(map(str, args))))
elif any(node.name == 'ReduceDataset' for node in g.as_graph_def().node):
print("{}({}) uses tf.data.Dataset.reduce.".format(
f.__name__, ', '.join(map(str, args))))
else:
print("{}({}) gets unrolled.".format(
f.__name__, ', '.join(map(str, args))))
tf.function
的静态展开
@tf.function
def for_in_range():
x = 0
for i in range(5):
x += i
return x
test_dynamically_unrolled(for_in_range)
for_in_range() gets unrolled.
@tf.function
def for_in_tfrange():
x = tf.constant(0, dtype=tf.int32)
for i in tf.range(5):
x += i
return x
test_dynamically_unrolled(for_in_tfrange)
for_in_tfrange() uses tf.while_loop.
@tf.function
def for_in_tfdataset():
x = tf.constant(0, dtype=tf.int64)
for i in tf.data.Dataset.range(5):
x += i
return x
test_dynamically_unrolled(for_in_tfdataset)
for_in_tfdataset() uses tf.data.Dataset.reduce.
@tf.function
def while_py_cond():
x = 5
while x > 0:
x -= 1
return x
test_dynamically_unrolled(while_py_cond)
while_py_cond() gets unrolled.
@tf.function
def while_tf_cond():
x = tf.constant(5)
while x > 0:
x -= 1
return x
test_dynamically_unrolled(while_tf_cond)
while_tf_cond() uses tf.while_loop.
感兴趣的可以去看原文 Better performance with tf.function Functions, not Sessions AutoGraph Reference tf.Module Tensorflow2.0 学习笔记之静态图转换器 Tensorflow2.0 学习笔记之状态容器