inner jit functions are re-traced (and re-compiled) #7155
Thanks for observing this, and describing it so clearly! See also #4572 (the first two comments in the thread), and #3847. This behavior is currently intended, roughly speaking. The XLA compiler effectively inlines all code, in part to make buffer assignment easier, and that means that it can't reuse subfunction code in this way. That said, even though XLA won't avoid recompiling these sub-functions, JAX could avoid re-tracing them. (Moreover, it's plausible that XLA will be upgraded to avoid recompilations of these subfunctions in the foreseeable future, at least on certain backends, like CPU. And we'd need to avoid retracing to be able to leverage that.) This issue might be a duplicate of #3847, though it's not clear: in that thread @shoyer had reason to believe JAX tracing was slow, but that could either be a retracing issue (as described here) or just having high tracing overheads (or the time could've been spent in something else, since IIUC the only directly-measured tracing times in that thread proved to be small).
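For concreteness, the division of labour described above can be seen at the jaxpr level: JAX keeps the nested call structure in the jaxpr, and it is XLA that inlines everything during compilation. A minimal sketch (the name of the call primitive varies across JAX versions):

```python
import jax
import jax.numpy as jnp

@jax.jit
def inner_fn(x):
    return 2 * x

@jax.jit
def outer_fn(x):
    for _ in range(3):
        x = inner_fn(x)
    return x

# The jaxpr for outer_fn still shows the nested jit calls as separate call
# primitives (named pjit or xla_call depending on the JAX version); it is
# XLA, further downstream, that inlines them into one computation.
print(jax.make_jaxpr(outer_fn)(jnp.arange(4)))
```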
A classic dilemma! You could try using
Maybe using
A concrete next step on our end is to add thorough logging so that these times are super easy to inspect.
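Until such logging exists, one rough way to inspect these times by hand is to compare the first (trace + compile) call against a later cached call; a minimal sketch, not part of the proposal above:

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return (x * 2 + 1).sum()

x = jnp.arange(1000)

t0 = time.perf_counter()
f(x).block_until_ready()   # first call: trace + compile + run
t1 = time.perf_counter()
f(x).block_until_ready()   # second call: cache hit, run only
t2 = time.perf_counter()

print(f"first call  (trace+compile+run): {t1 - t0:.4f} s")
print(f"second call (cached):            {t2 - t1:.4f} s")
```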
Thanks for the thorough response and the feedback.
I actually just did that, and it's nice that it allows selecting the best trade-off with the unroll parameter. The problem I am facing seems to be at a different level though, because for some reason JAX/XLA really doesn't want to fuse the operations in the looped operation (what is
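The suggestion being replied to is truncated above; assuming it was a loop primitive with an unroll parameter, here is a minimal sketch with jax.lax.scan of the trade-off being described (the function names are illustrative, not from the issue):

```python
import jax
import jax.numpy as jnp

def inner_fn(x):
    return 2 * x

@jax.jit
def outer_fn(x):
    def body(carry, _):
        return inner_fn(carry), None

    # `unroll` controls how many loop iterations are inlined per step of the
    # underlying loop, trading kernel-launch overhead for compile time.
    x, _ = jax.lax.scan(body, x, xs=None, length=10, unroll=2)
    return x

outer_fn(jnp.arange(5))
```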
Thanks, I also saw that in the profiler between typing up the issue and reading your response :)
I think that would be good, even if only to avoid confusion for the users ;)
Hrm, my very basic understanding is that XLA:GPU will codegen kernels whenever it doesn't need cuBLAS / cuDNN (as I recently commented here and here). But from your description of bitwise operations and array index munging, those things sound plausibly XLA-codegen-able. If you could isolate a repro, I could forward it to the XLA:GPU folks so they can look at whether things should be fused better.
Looking into my problem a bit further, I don't think it is a JAX issue, so I don't want to open another thread for it. It seems we are just hitting some built-in thresholds at which XLA prefers to split operations into several kernels, which are then executed sequentially. I would like to force it not to do that, but this seems hard-coded. I'm not 100% sure that this is what happens, because changing the code a bit (removing/changing the jax.lax.gather/jnp.take operations I have in there) changes this behavior, and I can get a fused kernel whose operation count far exceeds the threshold I found in the XLA code. With the gather, depending on how complex I make the function, I can also provoke the splits to occur in different places, so I think there are several splitting heuristics at work, which makes this really confusing to look at.
For custom kernels, take a look at dfm/extending-jax too!
Amazing spelunking! But if you want some more info, and maybe even some improvements, from real live XLA:GPU team members about this, we just have to give them a repro to poke at :)
Alright, I tried to boil it down to the essence in the following. So basically I have this piece of code which repeatedly applies a vector-valued function:

```python
import jax
import jax.numpy as jnp
import numpy as np

def col_fn(col):  # applied to a column of a 4x4 matrix
    for _ in range(3):  # number of operations inside can influence fusing behaviour
        col = col << 7
    return col

def inner_fn(m):
    col_vmap = jax.vmap(col_fn, in_axes=1, out_axes=1)
    # col_vmap = col_fn  # vmap or no vmap seems to make NO difference

    # apply col_fn to each column
    m = col_vmap(m)

    # then apply col_fn to each diagonal:
    # moving diagonals into columns
    diag_map = np.array([
        0, 1, 2, 3,
        5, 6, 7, 4,
        10, 11, 8, 9,
        15, 12, 13, 14
    ])

    # option 1: jax.lax.gather : results in splits; splits always occur after a gather
    # or concatenate, not necessarily aligned with iteration counts
    gdn = jax.lax.GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1))
    indices = jnp.array(((0, 0), (0, 1), (0, 2), (0, 3), (1, 1), (1, 2), (1, 3), (1, 0),
                         (2, 2), (2, 3), (2, 0), (2, 1), (3, 3), (3, 0), (3, 1), (3, 2)))
    diag_m = jax.lax.gather(m, indices, gdn, slice_sizes=(1, 1), unique_indices=True).reshape((4, 4))

    # option 2: jnp.take : same as option 1 (but with some additional instructions that appear to be argument checks)
    diag_map = diag_map.reshape((4, 4))
    diag_m = jnp.take(m, diag_map)

    # option 3: flatten - lookup - reshape : seems to either fuse completely or create a fuse
    # per iteration in outer_fn, depending on iterations in col_fn
    # diag_m = m.ravel()[diag_map].reshape(4, 4)

    # option 4: no index mapping : will always result in a single large fuse (no matter how complex col_fn)
    # diag_m = m
    # diag_m = col_vmap(diag_m)

    return diag_m

@jax.jit
def outer_fn(m):
    old_m = m
    for _ in range(7):  # number of iterations influences number of kernels
        m = inner_fn(m)
    m = m + old_m  # this gets absorbed into the last split
    return m

m = np.arange(16).reshape((4, 4))
outer_fn(m).block_until_ready()
```

As you can see, I tried different ways of reorganising the matrix and they result in different fusing outcomes, which I investigated by looking at the outputs produced when setting the
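One common way to obtain such dumps, assuming a typical workflow rather than knowing exactly which setting was used here, is XLA's --xla_dump_to flag:

```python
import os
# XLA flags have to be set before jax/XLA is initialized,
# i.e. before the first compiled call.
os.environ["XLA_FLAGS"] = "--xla_dump_to=/tmp/xla_dump"

import jax
import jax.numpy as jnp

@jax.jit
def f(x):  # stand-in; the reproduction above can be run instead
    return (x << 7) + x

f(jnp.arange(16, dtype=jnp.uint32)).block_until_ready()
# /tmp/xla_dump now contains the HLO before and after optimization,
# one set of files per compiled executable, showing the fusion decisions.
```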
Further: in reality, col_fn looks more like the following:

```python
def col_fn_complex(col):  # applied to a column of a 4x4 matrix
    a, b, c, d = col
    for _ in range(3):  # number of operations inside can influence fusing behavior
        a = a << 7
        b = b + a
        c = c * b
    col = jnp.array([a, b, c, d])
    return col
```

I haven't been able to do this without at some point using a concatenating operation to produce the output vector (or using an index_update), which results in even more splits of the fusions in the output HLO dump. Other things I tried:
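As a side note on the index_update route mentioned above: in current JAX it is spelled with the .at property. A hypothetical sketch of building the output vector that way (it still lowers to a scatter, so it may not fuse any better):

```python
import jax.numpy as jnp

def col_fn_complex_at(col):  # hypothetical variant, not from the original issue
    a, b, c, d = col
    for _ in range(3):
        a = a << 7
        b = b + a
        c = c * b
    # build the output with index updates instead of jnp.array / concatenate
    out = jnp.zeros_like(col)
    out = out.at[0].set(a)
    out = out.at[1].set(b)
    out = out.at[2].set(c)
    out = out.at[3].set(d)
    return out
```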
I tried running the example that started this issue and it seems like we don't get recompilation anymore?

```python
import jax
import jax.numpy as jnp
from functools import partial

@jax.jit
def inner_fn(state):
    print("entering inner_fn now")
    return 2 * state

@jax.jit
def outer_fn(x):
    print("entering outer_fn now")
    old_x = x
    for _ in range(10):
        x = inner_fn(x)
    x = x + old_x
    return x

with jax.log_compiles(True):
    state = jnp.arange(5, dtype=jnp.uint32)
    inner_fn(state)
    outer_fn(state)
```

output:
@KeAWang thanks so much for checking! That makes sense: since
It seems that `jit` decorators are ignored for nested functions. Consider the following example (simplified from an actual program, where `inner_fn` is slightly more complex):
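The example itself is not quoted here; judging from the comment above that re-runs "the example that started this issue", it was presumably essentially the following sketch:

```python
import jax
import jax.numpy as jnp

@jax.jit
def inner_fn(state):
    print("entering inner_fn now")
    return 2 * state

@jax.jit
def outer_fn(x):
    print("entering outer_fn now")
    old_x = x
    for _ in range(10):
        x = inner_fn(x)
    x = x + old_x
    return x

state = jnp.arange(5, dtype=jnp.uint32)
outer_fn(state)
```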
Expected behavior: As both functions are decorated with jit, `inner_fn` should be compiled once and the result re-used by the plain for loop inside of `outer_fn`. Consequently, the output would show only one emission of "entering inner_fn now".

Actual behavior: The jit decorator for `inner_fn` seems to be ignored: "entering inner_fn now" is printed 10 times, and setting the env var `JAX_LOG_COMPILES=1` only prints one line for `Compiling outer_fn`. In short, we observe the same behavior as if the jit decorator were absent for `inner_fn`.

Why a bug?: The current behavior negates any of the caching done for jitted functions, resulting in repeated tracing of a function with identical tracers and thus inflated compilation times, and it can easily surprise the user.
Other considerations:

- `fori_loop`: Using a `fori_loop` results in `inner_fn` being visited only once, so that will mitigate the issue for this particular example (see the sketch after this list). However, we were actually having trouble with it, as it prevents unrolling the loop, and our `inner_fn` was small enough that we were actually seeing performance hits from CUDA kernel launches for `inner_fn` in this case.
- `inline=True` for the jit decorator of `inner_fn`: In this case the compilation will compile 10 separate sub-jaxprs for `inner_fn` that are not inlined but are all invoked via `xla_call` (when inspecting this via `jax.make_jaxpr`). We see that the jit decorator is therefore not completely without effect, but this is arguably the worst case (separate `xla_call` overhead as in the `fori_loop` AND multiple compilations of `inner_fn`).
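A sketch of the `fori_loop` variant that the first consideration refers to (forward-referenced in the list above); this is an illustration, not code from the issue:

```python
import jax
import jax.numpy as jnp

@jax.jit
def inner_fn(state):
    return 2 * state

@jax.jit
def outer_fn_fori(x):
    old_x = x
    # body_fun is traced a single time, so inner_fn is only visited once,
    # at the cost of the loop no longer being unrolled.
    x = jax.lax.fori_loop(0, 10, lambda i, v: inner_fn(v), x)
    return x + old_x

outer_fn_fori(jnp.arange(5, dtype=jnp.uint32))
```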