Improving the lowering and compilation of unrolled lax.scan loops
#25336
Replies: 2 comments
-
I created the following example: https://gist.github.com/carlosgmartin/a3055c7605157a54d48d108226a48b97. Output:
As you can see, the lowering time is shorter, and the lowered expression smaller, for my (partial) re-implementations of …
It would be nice if there were a StableHLO primitive equivalent to the …
In other words, we could optimize the loop body itself but otherwise treat it as a "unit", and then "copy-paste" or "tile" it repeatedly at the very end of compilation.
-
I've also opened an issue about this at openxla/stablehlo#2664.
-
`lax.scan` being slow is a common issue. Here are a few examples:
- `jit` functions are re-traced (and re-compiled) #7155 (comment)
- `scan` in `jnp.searchsorted`, when method 'scan_unrolled' is specified. On GPU, XLA's 'scan' (fori_loop) implementation launches multiple calls to the body_fun GPU kernel, whereas a fully unrolled scan can be fused into a single kernel launch. #17509
- `lax.scan` on `map_coordinates` slower on GPU than on CPU? #10794 (comment)
Often, the reason for this slowness is that `lax.scan` causes multiple kernel launches on GPU, as discussed here.
One solution to this problem is to use unrolling. The disadvantage of that solution is that it increases lowering and compilation times, sometimes dramatically so.
If JAX improved the lowering and compilation of such unrolled loops, it would allow users to get the best of both worlds: Fast execution and fast lowering/compilation.
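To make the trade-off concrete, here is a minimal sketch (not taken from any of the issues linked above) that lowers the same scan with different `unroll` values and reports the lowering time and the size of the StableHLO text. The `running_sum` and `lowering_stats` helpers are hypothetical names I made up for illustration, and the length of the lowered text is only a rough proxy for program size.

```python
import time

import jax
import jax.numpy as jnp
from jax import lax


def running_sum(xs, unroll):
    # Carry the partial sum through the scan and emit it at every step.
    def step(carry, x):
        carry = carry + x
        return carry, carry

    _, ys = lax.scan(step, jnp.zeros((), xs.dtype), xs, unroll=unroll)
    return ys


def lowering_stats(unroll, length=64):
    xs = jnp.arange(length, dtype=jnp.float32)
    start = time.perf_counter()
    lowered = jax.jit(lambda v: running_sum(v, unroll)).lower(xs)
    elapsed = time.perf_counter() - start
    # Length of the StableHLO text: a rough proxy for lowered program size.
    return elapsed, len(lowered.as_text())


# unroll=1 keeps a rolled loop; unroll=64 fully unrolls a length-64 scan.
for unroll in (1, 8, 64):
    elapsed, size = lowering_stats(unroll)
    print(f"unroll={unroll:>2}: lowering {elapsed:.4f}s, {size} chars of StableHLO")
```

The absolute numbers depend on the JAX version and the machine, so treat them as indicative only. The same comparison can be extended to compilation by calling `.compile()` on the lowered object and timing that as well.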
Before we get to compilation, let's see how much we can optimize lowering.
Let's start with a simple example, which computes discounted returns for reinforcement learning:
Here is the output:
As you can see, there are multiple patterns that repeat 10 times (which is the number of `steps`), as well as a single StableHLO function at the end, which is called once per step.
As a first step toward improving lowering/compile time (which could also facilitate further lowering/compile time optimizations downstream), it seems to me that it should be possible to fold all of the recurring patterns into a single StableHLO function like the one at the end, so that the only things that get repeated 10 times are `call` instructions, just one per step. In particular, we should be able to fold the following repeating pattern into a single function:
Would it be possible to do that, as a start?
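For what it's worth, a rough user-level approximation of this idea is to wrap the loop body in its own `jax.jit` and unroll with a plain Python loop. The sketch below does that for discounted returns. The `step` and `discounted_returns` names are hypothetical. It also rests on an assumption that should be checked by inspecting the text on your JAX version: that a nested `jax.jit` call is lowered to a call into a separate function rather than being inlined. This is not the StableHLO-level folding proposed above, only an illustration of what "one call per step" could look like.

```python
import jax
import jax.numpy as jnp


@jax.jit
def step(ret, reward, discount):
    # One backward step of the recursion G_t = r_t + discount * G_{t+1}.
    return reward + discount * ret


def discounted_returns(rewards, discount):
    # Python-level loop: unrolled at trace time, one `step` call per reward.
    ret = jnp.zeros((), rewards.dtype)
    out = []
    for t in reversed(range(rewards.shape[0])):
        ret = step(ret, rewards[t], discount)
        out.append(ret)
    # Results were collected back to front, so restore time order.
    return jnp.stack(out[::-1])


rewards = jnp.ones(10, dtype=jnp.float32)
lowered = jax.jit(discounted_returns).lower(rewards, 0.9)
text = lowered.as_text()

# Counts are printed rather than asserted, since they depend on the JAX version.
print("call instructions:", text.count("call @"))
print("function definitions:", text.count("func.func"))
```

Note that the indexing `rewards[t]` still happens outside `step`, so some per-step operations remain repeated in the caller. Moving more of the body into the wrapped function would shrink the repeated part further.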