
vectorized_map causes tf.function retracing. #241

Open · sebastian-sz opened this issue Mar 30, 2022 · 20 comments

@sebastian-sz (Contributor)

Problem description

It seems that applying some layers that use BaseImageAugmentationLayer with self.auto_vectorize=True to batched input causes tf.function retracing:

import tensorflow as tf
from keras_cv.layers import Solarization

layer = Solarization()  # or Equalization()
rng = tf.random.Generator.from_seed(1234)

for _ in range(50):
    dummy_input = rng.uniform(
        shape=(1, 224, 224, 3), minval=0, maxval=255
    )
    layer(dummy_input)

raises

WARNING:tensorflow:5 out of the last 5 calls to <function pfor.<locals>.f at 0x7f80a2a544c0> triggered tf.function retracing. (...)
WARNING:tensorflow:6 out of the last 6 calls to <function pfor.<locals>.f at 0x7f80a2a544c0> triggered tf.function retracing. (...)

Benchmarks

Running simple benchmarks confirms the performance degradation with tf.function on batched input:

import time

import tensorflow as tf
from keras_cv.layers import Solarization

use_tf_function = False

rng = tf.random.Generator.from_seed(1234)
layer = Solarization()
results = []

if use_tf_function:
    layer.augment_image = tf.function(layer.augment_image, jit_compile=True)

# Warmup
for _ in range(10):
    layer(rng.uniform(shape=(24, 224, 224, 3), maxval=256))

# Benchmark
for _ in range(100):
    dummy_input = rng.uniform(shape=(24, 224, 224, 3), maxval=256)
    start = time.perf_counter()
    layer(dummy_input)
    stop = time.perf_counter()
    results.append(stop-start)

print(tf.reduce_mean(results))

Case 1: auto_vectorize=True

Without tf.function: 0.067 ms.
With: 0.079 ms.

Case 2: auto_vectorize=False

The issue doesn't pop up with non-batched input, e.g. (224, 224, 3), or if one sets self.auto_vectorize=False in the layer.

Setting self.auto_vectorize=False yields:
Without tf.function: 0.017 ms
With: 0.013 ms.

Case 3: override _batch_augment (if possible)

For vectorized operations, the fastest option is still overriding _batch_augment to return self._augment(inputs), as sketched below. This yields:
Without tf.function: 0.0059 ms
With: 0.0016 ms
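
For reference, a minimal sketch of such an override (the subclass name here is hypothetical; _batch_augment and _augment are the hooks described above):

from keras_cv.layers import Solarization

# Hypothetical subclass: since the augmentation is already vectorized,
# apply _augment to the whole (batch, height, width, channels) tensor
# instead of mapping it over individual images.
class VectorizedSolarization(Solarization):
    def _batch_augment(self, inputs):
        return self._augment(inputs)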

@sebastian-sz (Contributor, Author)

Improving the performance of Solarization is something I wanted to discuss in a separate issue.

Here I wanted to point out that multiple keras_cv preprocessing layers are affected by retracing when applied to batched input.

@bhack (Contributor) commented Mar 30, 2022

I don't think we are going to be impacted by the auto-vectorization/retracing in real use cases:

import tensorflow as tf
from keras_cv.layers.preprocessing import Solarization
from tensorflow.keras.models import Sequential

layer = Solarization()  # or Equalization()
rng = tf.random.Generator.from_seed(1234)

model = Sequential()
model.add(layer)
model.build([24, 224, 224, 3])

for _ in range(50):
    x = rng.uniform(
        shape=(24, 224, 224, 3), minval=0, maxval=255, dtype=tf.float32)
    _ = model.predict(x)

See my comments in tensorflow/tensorflow#42441

@LukeWood (Contributor) commented Mar 31, 2022

Thanks for the detailed report @sebastian-sz

FYI @qlzh727

@sebastian-sz (Contributor, Author)

@bhack fair point - using Sequential and the .predict method silences the warnings and unifies inference time at ~0.024 ms, regardless of whether self.auto_vectorize is True or False.

This is, however, a bit slower than calling the layer directly with self.auto_vectorize=False (0.013 ms) or with native vectorization (0.0016 ms).

Also, model.predict cannot be used inside a tf.data.Dataset map function - one needs to rely on the __call__ method. I'm unsure how execution works inside tf.data.Dataset: I do see differences in inference time depending on self.auto_vectorize, but there are no warnings regarding retracing.

import time
import tensorflow as tf
from keras_cv.layers import Solarization

model = tf.keras.Sequential()
model.add(Solarization())
model.build([24, 224, 224, 3])

rng = tf.random.Generator.from_seed(1234)
ds = tf.data.Dataset.from_tensor_slices([rng.uniform(shape=(24, 224, 224, 3), maxval=256)]).repeat(100)
ds = ds.map(lambda x: model(x))

for _ in ds:
    continue

start = time.perf_counter()
for _ in ds:
    continue
stop = time.perf_counter()

print((stop - start) / 100)

@sebastian-sz (Contributor, Author)

It seems that wrapping the entire layer in tf.function (even better with jit_compile=True) also silences the warnings and provides decent performance in eager mode:

@tf.function(jit_compile=True)
def apply(x):
    return layer(x)

0.0015 ms for self.auto_vectorize=True
0.0022 ms for self.auto_vectorize=False
0.0015 ms for native vectorization.

This issue can be closed from my end. If no further comments appear I will close this issue starting next week.
Thanks for the help!

@bhack (Contributor) commented Mar 31, 2022

Generally, it is not the best approach to benchmark predict in a loop:

https://keras.io/getting_started/faq/#whats-the-difference-between-model-methods-predict-and-call
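
Roughly, the distinction the FAQ draws (an illustrative toy model, not from the thread):

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Lambda(lambda t: t * 2.0)])
x = tf.random.uniform(shape=(24, 8, 8, 3))

y1 = model(x)          # direct call: a single dispatch, suited to tight benchmark loops
y2 = model.predict(x)  # full inference loop: batching and callback machinery add per-call overhead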

For controlling the XLA compilation, see my proposal at:
#165 (comment)

@LukeWood (Contributor)

I wonder if we should @tf.function our call methods by default to:

a.) mute warnings
b.) make performance consistent.
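
As a sketch of that idea (an illustrative layer, not the actual keras_cv base class):

import tensorflow as tf

class ExampleAugmentationLayer(tf.keras.layers.Layer):
    # Wrapping call in tf.function means repeated eager calls with
    # same-shaped batches reuse a single trace instead of retracing pfor.
    @tf.function
    def call(self, inputs):
        # Stand-in body; a real layer would dispatch to its augment logic.
        return tf.vectorized_map(lambda image: 255.0 - image, inputs)

layer = ExampleAugmentationLayer()
_ = layer(tf.zeros((24, 224, 224, 3)))  # traced once, then reused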

@sebastian-sz sebastian-sz changed the title vectorized_map causes tf.function retracing and lowers performance with tf.function. vectorized_map causes tf.function retracing. Mar 31, 2022
@bhack (Contributor) commented Mar 31, 2022

@sebastian-sz Can you try your initial example with the latest tf-nightly version?

@sebastian-sz (Contributor, Author)

@bhack Running with 2.9.0-dev20220329 gives very similar numbers, and the retracing persists.

I am, however, happy with the performance of the tf.function wrapper.

@bhack (Contributor) commented Mar 31, 2022

I wonder if we should @tf.function our call methods by default to:

a.) mute warnings b.) make performance consistent.

tf.vectorized_map is internally going to trace the function if we are in the default eager mode, but models are tf.function-wrapped by default.

If we want to keep the critical section eager-compatible, we need to automate a conditional call to the standard map_fn in the base-class overload we have done (when we are in eager mode), as in the sketch below.
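
i.e. something like this dispatch helper (a sketch; the function name is illustrative):

import tensorflow as tf

def batch_map(fn, inputs):
    # Proposed dispatch: plain map_fn in eager mode (no tracing, hence
    # no retracing warnings); vectorized_map inside a graph, where its
    # trace becomes part of the enclosing tf.function.
    if tf.executing_eagerly():
        return tf.map_fn(fn, inputs)
    return tf.vectorized_map(fn, inputs)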

@bhack (Contributor) commented Mar 31, 2022

More generally, I think this use of "layer as op" is still a bit confusing:

#122 (comment)

@LukeWood (Contributor) commented Apr 3, 2022

It seems that wrapping the entire layer in tf.function (even better with jit_compile=True) also silences the warnings and provides decent performance in eager mode:

@tf.function(jit_compile=True)
def apply(x):
    return layer(x)

0.0015 ms for self.auto_vectorize=True
0.0022 ms for self.auto_vectorize=False
0.0015 ms for native vectorization.

This issue can be closed from my end. If no further comments appear I will close this issue starting next week. Thanks for the help!

I don’t want to close it yet because I feel we need to figure out how to effectively communicate this recommendation to users 🤔

@LukeWood (Contributor) commented Apr 3, 2022

I wonder if we should @tf.function our call methods by default to:
a.) mute warnings b.) make performance consistent.

tf.vectorized_map is internally going to trace the function if we are in the default eager mode, but models are tf.function-wrapped by default.

If we want to keep the critical section eager-compatible, we need to automate a conditional call to the standard map_fn in the base-class overload we have done (when we are in eager mode).

We could also @tf.function the base layer's call method if needed, or the _batch_augment method.

@bhack (Contributor) commented Apr 3, 2022

We could also @tf.function the base layer's call method if needed, or the _batch_augment method.

It really depends... do you want to be silently in graph mode for some functions?

As the end user/developer doesn't control the vectorization in the API, it is something you would be doing behind the scenes without any notification.

At least model.compile still gives the end user control over both eager and jit_compile (XLA) execution through its own args:

compile(
    optimizer='rmsprop',
    loss=None,
    metrics=None,
    loss_weights=None,
    weighted_metrics=None,
    run_eagerly=None,
    steps_per_execution=None,
    jit_compile=None,
    **kwargs
)

@bhack (Contributor) commented Apr 6, 2022

/cc @mdanatg
I suppose the situation hasn't evolved since Oct 2020: tensorflow/tensorflow#43710 (comment). What do you think?

@mdanatg commented Apr 6, 2022

I think we now have better mechanisms to protect against excessive retracing. Is the error coming from a standard Keras layer, or is it a custom one?

@bhack (Contributor) commented Apr 6, 2022

As we don't have a (cached) function object with tf.vectorized_map, it is hard not to retrace the function.

So I still believe it is better to automatically call map_fn in eager mode and tf.vectorized_map in graph mode.

@bhack (Contributor) commented Apr 6, 2022

When graph creation is done implicitly by the API design, as in tf.data, it is documented explicitly:

https://www.tensorflow.org/api_docs/python/tf/data/Dataset?hl=en#map

map_func can accept as arguments and return any type of dataset element.
Note that irrespective of the context in which map_func is defined (eager vs. graph), tf.data traces the function and executes it as a graph. To use Python code inside of the function you have a few options:

  1. Rely on AutoGraph to convert Python code into an equivalent graph computation. The downside of this approach is that AutoGraph can convert some but not all Python code.
  2. Use tf.py_function, which allows you to write arbitrary Python code but will generally result in worse performance than 1). For example:
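
The example elided from the quoted docs is along these lines (a sketch, not the verbatim doc snippet):

import tensorflow as tf

d = tf.data.Dataset.from_tensor_slices([1.0, 2.0, 3.0])

def python_double(x):
    # Arbitrary Python code runs here, outside the traced graph.
    return x * 2

d = d.map(lambda x: tf.py_function(func=python_double, inp=[x], Tout=tf.float32))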

@LukeWood (Contributor)

Ok, after lots more digging and time, I agree with you @bhack: we should apply map_fn in eager mode and vectorized_map in graph mode. We can tackle this after @divyashreepathihalli migrates BaseImageAugmentationLayer to KerasCV; it will be easier to update once it's in KerasCV.
