
inner jit functions are re-traced (and re-compiled) #7155

Closed
3 tasks done
lumip opened this issue Jul 1, 2021 · 9 comments · Fixed by #15048
Assignees: mattjj
Labels: enhancement (New feature or request)

Comments

lumip (Contributor) commented Jul 1, 2021

It seems that jit decorators are ignored for nested functions. Consider the following example (simplified from an actual program, where inner_fn is slightly more complex):

import jax
import jax.numpy as jnp
from functools import partial

@partial(jax.jit, inline=True)
def inner_fn(state):
    print("entering inner_fn now")
    return 2*state

@jax.jit
def outer_fn(x):
    print("entering outer_fn now")
    old_x = x
    for _ in range(10):
        x = inner_fn(x)
    x = x + old_x
    return x

state = jnp.arange(5, dtype=jnp.uint32)
outer_fn(state).block_until_ready()

Expected behavior: As both functions are decorated with jit, inner_fn should be compiled once and the result re-used by the plain for loop inside of outer_fn. Consequently, the output would show only one emission of "entering inner_fn now".

Actual behavior: The jit decorator for inner_fn seems to be ignored: "entering inner_fn now" is printed 10 times and setting env var JAX_LOG_COMPILES=1 only prints one line for Compiling outer_fn. In short, we observe the same behavior as if the jit decorator were absent for inner_fn.

Why a bug?: The current behavior negates the caching done for jitted functions, resulting in repeated tracing of a function with identical tracers and thus inflated compilation times, and it can easily surprise the user.

Other considerations:

  1. Using fori_loop: Using a fori_loop results in inner_fn being visited only once, so that will mitigate the issue for this particular example (see the sketch after this list). However, we were actually having trouble with it, as it prevents unrolling the loop, and our inner_fn was small enough that we were actually seeing performance hits from CUDA kernel launches for inner_fn in this case.
  2. No inline=True for the jit decorator of inner_fn: In this case the compilation produces 10 separate sub-jaxprs for inner_fn that are not inlined but all invoked via xla_call (when inspecting this via jax.make_jaxpr). We see that the jit decorator is therefore not completely without effect, but this is arguably the worst case (separate xla_call overhead as in the fori_loop AND multiple compilations of inner_fn).
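
For reference, a minimal sketch of the fori_loop variant mentioned in point 1, using the same inner_fn and outer_fn as above:

import jax
import jax.numpy as jnp
from functools import partial

@partial(jax.jit, inline=True)
def inner_fn(state):
    return 2 * state

@jax.jit
def outer_fn(x):
    old_x = x
    # fori_loop traces the body a single time; XLA receives a While op
    # instead of ten unrolled copies of inner_fn.
    x = jax.lax.fori_loop(0, 10, lambda i, s: inner_fn(s), x)
    return x + old_x

state = jnp.arange(5, dtype=jnp.uint32)
outer_fn(state).block_until_ready()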

lumip added the bug (Something isn't working) label on Jul 1, 2021
mattjj added the question (Questions for the JAX team) and enhancement (New feature or request) labels and removed the bug and question labels on Jul 1, 2021
mattjj self-assigned this on Jul 1, 2021
mattjj (Collaborator) commented Jul 1, 2021

Thanks for observing this, and describing it so clearly!

See also #4572 (the first two comments in the thread), and #3847.

This behavior is currently intended, roughly speaking. The XLA compiler effectively inlines all code, in part to make buffer assignment easier, and that means that it can't reuse subfunction code in this way.

That said, even though XLA won't avoid recompiling these sub-functions, JAX could avoid re-tracing them. (Moreover, it's plausible that XLA will be upgraded to avoid recompilations of these subfunctions in the foreseeable future, at least on certain backends, like CPU. And we'd need to avoid retracing to be able to leverage that.)

This issue might be a duplicate of #3847, though it's not clear: in that thread @shoyer had reason to believe JAX tracing was slow, but that could either be a retracing issue (as described here) or just high tracing overhead (or the time could have been spent elsewhere, since IIUC the only directly-measured tracing times in that thread turned out to be small).

  1. Using fori_loop: Using a fori_loop results in inner_fn being visited only once, so that will mitigate the issue for this particular example. However, we were actually having trouble with it, as it prevents unrolling the loop, and our inner_fn was small enough that we were actually seeing performance hits from CUDA kernel launches for inner_fn in this case.

A classic dilemma! You could try using lax.scan along with its unroll parameter. That should save on kernel launches, though on GPU the kernel launch overheads might still be high. (On TPU the whole jitted computation is always compiled into one 'kernel', so there are no analogous overheads there.)
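
For concreteness, a minimal sketch of that scan + unroll suggestion applied to the example from the report (the unroll value of 5 is just an illustration):

import jax
import jax.numpy as jnp
from functools import partial

@partial(jax.jit, inline=True)
def inner_fn(state):
    return 2 * state

@jax.jit
def outer_fn(x):
    old_x = x
    # scan traces the body only once; unroll=5 unrolls 5 steps per loop
    # iteration, trading compiled code size against kernel-launch overhead.
    x, _ = jax.lax.scan(lambda s, _: (inner_fn(s), None), x, xs=None, length=10, unroll=5)
    return x + old_x

state = jnp.arange(5, dtype=jnp.uint32)
outer_fn(state).block_until_ready()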

  2. No inline=True for the jit decorator of inner_fn: In this case the compilation produces 10 separate sub-jaxprs for inner_fn that are not inlined but all invoked via xla_call (when inspecting this via jax.make_jaxpr). We see that the jit decorator is therefore not completely without effect, but this is arguably the worst case (separate xla_call overhead as in the fori_loop AND multiple compilations of inner_fn).

xla_call shouldn't incur any overheads; XLA just inlines all calls anyway. That is, while the representation is different at the jaxpr level, and in the un-optimized HLO we hand to XLA, the optimized HLO will remove these calls and inline everything.
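
For illustration, the difference in representation can be inspected with jax.make_jaxpr, as mentioned in the report; outer_body below is a hypothetical name for the un-jitted outer loop, and the exact primitive names depend on the JAX version:

import jax
import jax.numpy as jnp

@jax.jit
def inner_fn(state):
    return 2 * state

def outer_body(x):
    old_x = x
    for _ in range(10):
        x = inner_fn(x)
    return x + old_x

# Each application of the jitted inner_fn shows up as a separate call
# primitive (xla_call in older releases, pjit in newer ones) wrapping the
# same small sub-jaxpr, rather than being inlined into the outer jaxpr.
print(jax.make_jaxpr(outer_body)(jnp.arange(5, dtype=jnp.uint32)))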

Maybe using lax.scan along with unroll is the best next thing to try. Another potential next step is to measure the tracing times carefully; we can mitigate those by avoiding retracing, but if the times are actually in XLA recompilation then there's not much we can do on the JAX side.

@mattjj mattjj changed the title jit decorators are ignored for nested functions inner jit functions are re-traced (and re-compiled) Jul 1, 2021
mattjj (Collaborator) commented Jul 1, 2021

A concrete next step on our end is to add thorough logging so that these times are super easy to inspect.

lumip (Contributor, Author) commented Jul 1, 2021

Thanks for the thorough response and the feedback.

You could try using lax.scan along with its unroll parameter.

I actually just did that, and it's nice that the unroll parameter lets me select the best trade-off. The problem I am facing seems to be at a different level, though: for some reason JAX/XLA really doesn't want to fuse the operations in the looped computation (what is inner_fn in the above). In my real code I have some bitwise operations in a vmap and some (static) reordering of array indices, but nothing that could not be fused as far as I understand. Yet I end up seeing a lot of small kernels originating from this in the CUDA profiler. I initially thought that this re-tracing was to blame, but it is not: the problem persists with all variants (plain for loop, fori_loop, and scan). It's probably best I try to narrow it down a bit more and open a new issue for that (but if there's any hint you could give just from the info above, I'd greatly appreciate it; is there perhaps a maximum complexity above which kernels are split / not fused?).

xla_call shouldn't incur any overheads; XLA just inlines all calls anyway.

Thanks, I also saw that in the profiler between typing up the issue and reading your response :)

That said, even though XLA won't avoid recompiling these sub-functions, JAX could avoid re-tracing them

I think that would be good, even if only to avoid confusion for the users ;)

mattjj (Collaborator) commented Jul 1, 2021

but if there's any hint you could give just from the info above, I'd greatly appreciate that. Is there perhaps a maximum complexity above which kernels are split / not fused?

Hrm, my very basic understanding is that XLA:GPU will codegen kernels whenever it doesn't need cuBLAS / cuDNN (as I recently commented here and here). But from your description of bitwise operations and array index munging, those things sound plausibly XLA-codegen-able.

If you could isolate a repro, I could forward it to the XLA:GPU folks so they can look at whether things should be fused better.

lumip (Contributor, Author) commented Jul 2, 2021

Looking into my problem a bit further, I don't think it is a JAX issue, so I don't want to open another thread for it. It seems we are just hitting some built-in thresholds at which XLA prefers to split operations into several kernels that are then executed sequentially. I would like to force it not to do that, but this seems hard-coded. I'm not 100% sure that this is what happens, because changing the code a bit (removing/changing the jax.lax.gather/jnp.take operations I have in there) changes this behavior, and I can get a fused kernel whose operation count far exceeds the threshold I found in the XLA code. With the gather, depending on how complex I make the function, I can also provoke the splits to occur in different places, so I think there are several splitting heuristics at work, which makes this really confusing to look at.
Going forward, we will probably look into writing this as a hand-written kernel and using it with JAX as outlined in this fairly extensive discussion here, as it seems there's not much to be done about the current behavior.

mattjj (Collaborator) commented Jul 2, 2021

Going forward, we will probably look into writing this as a hand-written kernel and using it with JAX as outlined in this fairly extensive discussion here, as it seems there's not much to be done about the current behavior.

For custom kernels, take a look at dfm/extending-jax too!

With the gather, depending on how complex I make the function, I can also provoke the splits to occur in different places, so I think there are several splitting heuristics at work, which makes this really confusing to look at.

Amazing spelunking! But if you want some more info, and maybe even some improvements, from real live XLA:GPU team members about this, we just have to give them a repro to poke at :)

lumip (Contributor, Author) commented Jul 5, 2021

Alright, I tried to boil it down to its essence in the following. Basically, I have a piece of code which repeatedly applies a vector-valued function (col_fn) to all columns and (some form of) diagonals of a 4x4 matrix; this is realised by inner_fn, with the repeated application as a loop in outer_fn. My basic approach for inner_fn was to first vmap col_fn over the columns, then reorganise the matrix so that the diagonals I'm interested in are moved into columns, and apply the vmapped col_fn again. This remapping is what seems to be causing problems for me. Code (col_fn is just a stand-in right now, with a loop to adjust its complexity in terms of the number of simple operations it applies):

import jax
import jax.numpy as jnp
import numpy as np

def col_fn(col): # applied to a column of a 4x4 matrix
    for _ in range(3): # number of operations inside can influence fusing behaviour
        col = col << 7
    return col

def inner_fn(m):
    col_vmap = jax.vmap(col_fn, in_axes=1, out_axes=1)
    # col_vmap = col_fn # vmap or no vmap seems to make NO difference

    # apply col_fn to each column
    m = col_vmap(m)

    # then apply col_fn to each diagonal:
    # moving diagonals into columns
    diag_map = np.array([
         0,  1,  2,  3,
         5,  6,  7,  4,
        10, 11,  8,  9,
        15, 12, 13, 14
    ])

    # NOTE: only one of the following four options should be active at a time;
    # option 2 is active here, the others are commented out.

    # option 1: jax.lax.gather : results in splits; splits always occur after a gather or concatenate, not necessarily aligned with iteration counts
    # gdn = jax.lax.GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,1), start_index_map=(0,1))
    # indices = jnp.array(((0,0), (0,1), (0,2), (0,3), (1,1), (1,2), (1,3), (1,0), (2,2), (2,3), (2,0), (2,1), (3,3), (3,0), (3,1), (3,2)))
    # diag_m = jax.lax.gather(m, indices, gdn, slice_sizes=(1,1), unique_indices=True).reshape((4,4))

    # option 2: jnp.take : same as option 1 (but with some additional instructions that appear to be argument checks)
    diag_map = diag_map.reshape((4,4))
    diag_m = jnp.take(m, diag_map)

    # option 3: flatten - lookup - reshape : seems to either fuse completely or create a fuse per iteration in outer_fn, depending on iterations in col_fn
    # diag_m = m.ravel()[diag_map].reshape(4,4)

    # option 4: no index mapping : will always result in a single large fuse (no matter how complex col_fn)
    # diag_m = m
    # diag_m = col_vmap(diag_m)

    return diag_m


@jax.jit
def outer_fn(m):
    old_m = m
    for _ in range(7): # number of iterations influences number of kernels
        m = inner_fn(m)
    m = m + old_m # this gets absorbed into the last split
    return m

m = np.arange(16).reshape((4,4))
outer_fn(m).block_until_ready()

As you can see, I tried different ways of reorganising the matrix and they result in different fusing outcomes, which I investigated by looking at the outputs produced when setting the --xla_dump_hlo_as_text XLA flag.

  • Options 1 and 2 are basically identical and insert seemingly arbitrary splits after emitted gather operations in the HLO (i.e., a fusion will always end with a gather op). This roughly merges 2 iterations of inner_fn, so for 7 iterations I end up with 4 fusions: the first covers the first 1.5 iterations (stopping at the reorganising in the 2nd iteration), the second covers the second half of the 2nd iteration, the entire 3rd, and the first half of the 4th, and so on. This seems invariant to the number of operations in col_fn or the number of iterations of inner_fn.
  • Option 3: This either results in a single fusion combining all iterations OR in one fusion per iteration of inner_fn, depending entirely on the complexity/number of iterations in col_fn: in my tests, I get a single fusion for fewer than 13 iterations in col_fn and one fusion per inner_fn invocation for more. The splits again occur at the gather ops, so the first fusion covers the first half of the first invocation of inner_fn, and each following fusion covers a second half and the first half of the next iteration.
  • Option 4: No reorganising of the matrix takes place, so no gather ops are emitted and the result is a single fusion in all cases, no matter the complexity of col_fn or the number of iterations of inner_fn.
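
For completeness, here is a minimal sketch of how the HLO text dumps mentioned above can be produced; the dump directory and the toy function f are placeholders, and XLA_FLAGS has to be set before JAX initializes its backend:

import os

# Assumption: these are standard XLA debugging flags; they make the compiler
# write every module's HLO, including the optimized (fused) version, as text
# files into the given directory.
os.environ["XLA_FLAGS"] = "--xla_dump_to=/tmp/xla_dump --xla_dump_hlo_as_text"

import jax
import jax.numpy as jnp

@jax.jit
def f(x):  # placeholder for outer_fn from the repro above
    return (x << 7) + x

f(jnp.arange(16, dtype=jnp.uint32)).block_until_ready()
# The files ending in "after_optimizations.txt" should show the fusion
# boundaries discussed above.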

Further: In reality, col_fn is a bit more involved than shown above and requires interdependent updates to the components of the received vector, something akin to this (with some randomly chosen elementary operations):

def col_fn_complex(col): # applied to a column of a 4x4 matrix
    a,b,c,d = col
    for _ in range(3): # number of operations inside can influence fusing behavior
        a = a << 7
        b = b + a
        c = c * b
    col = jnp.array([a,b,c,d])
    return col

I haven't been able to do this without at some point using a concatenating operation to produce the output vector (or using an index_update), which results in even more splits of the fusions in the output HLO dump.

Other things I tried:

  • I have tried a different implementation in which col_fn receives the full matrix and a list of indices instead of a vector slice of the matrix, to avoid having to reorganise the matrix. However, this merely moves the gather from the reorganising in inner_fn to the then-required index lookup at the start of col_fn.
  • I have also tried extracting the diagonals by multiplying the matrix with a 0/1 masking matrix followed by a summation along the rows (see the sketch after this list). Doing that, I get the same result as in Option 3 above, with splits occurring at the reduce ops that replace the gather ops (plus additional add kernels for each iteration that are used by the reduce ops).
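
A minimal sketch of the masking idea from the last bullet, expressed here with a 16x16 one-hot selection matrix; the helper name reorder_with_mask is hypothetical:

import jax.numpy as jnp
import numpy as np

# Same permutation as diag_map in the repro above.
diag_map = np.array([
     0,  1,  2,  3,
     5,  6,  7,  4,
    10, 11,  8,  9,
    15, 12, 13, 14
])

# 16x16 0/1 selection matrix: row p has a single 1 at column diag_map[p].
mask = np.zeros((16, 16), dtype=np.uint32)
mask[np.arange(16), diag_map] = 1

def reorder_with_mask(m):  # hypothetical helper
    # Elementwise multiply with the mask and reduce over the last axis,
    # replacing the gather with a multiply + reduce.
    return jnp.sum(mask * m.ravel()[None, :], axis=-1).reshape(4, 4)

m = jnp.arange(16, dtype=jnp.uint32).reshape(4, 4)
assert (reorder_with_mask(m) == m.ravel()[diag_map].reshape(4, 4)).all()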

KeAWang commented Mar 17, 2023

I tried running the example that started this issue and it seems like we don't get recompilation anymore?

import jax
import jax.numpy as jnp
from functools import partial

@jax.jit
def inner_fn(state):
    print("entering inner_fn now")
    return 2*state

@jax.jit
def outer_fn(x):
    print("entering outer_fn now")
    old_x = x
    for _ in range(10):
        x = inner_fn(x)
    x = x + old_x
    return x

with jax.log_compiles(True):
    state = jnp.arange(5, dtype=jnp.uint32)
    inner_fn(state)
    outer_fn(state)

output:

Finished tracing + transforming inner_fn for pjit in 0.002679586410522461 sec
Compiling inner_fn for with global shapes and types [ShapedArray(uint32[5])]. Argument mapping: (GSPMDSharding({replicated}),).
Finished XLA compilation of jit(inner_fn) in 0.033182382583618164 sec
Finished tracing + transforming outer_fn for pjit in 0.00830388069152832 sec
Compiling outer_fn for with global shapes and types [ShapedArray(uint32[5])]. Argument mapping: (GSPMDSharding({replicated}),).
Finished XLA compilation of jit(outer_fn) in 0.062166690826416016 sec
entering inner_fn now
entering outer_fn now

mattjj (Collaborator) commented Mar 17, 2023

@KeAWang thanks so much for checking! That makes sense: since JAX_JIT_PJIT_API_MERGE=1 is now not only the default but the only option, we don't re-trace in these cases like we used to! Let's add this code as a test and then close the issue.
