Description
Working with complex jitted functions can lead to long compile times, which hurts interactive workflows. One option is to jit only pieces of the function. This works well as long as jit is the top-level transformation. When it is not, for example when other transformations such as grad or vmap are applied on top of the jitted pieces, execution slows down dramatically due to Python overhead. For example:
```python
import jax

@jax.jit
def fun(x):
    return x * 2

def loop(y):
    for _ in range(10):
        y = fun(y)
    return y

g = jax.grad(loop)(1.)
```
On investigation, JAX's C++ jit dispatch path misses its cache (`cache_miss` shows up in the profiler) and falls back to the Python jit path. The VLOG output shows this, for example with BatchTracers and JVPTracers. This is quite slow: 33 ms in my case, while the actual computation takes 1.1 ms.
```
I0603 13:10:37.163035 3754 jax_jit.cc:892] ComputeSignature failed: INVALID_ARGUMENT: Not supported: The C++ ToPyArgSignature only accepts Buffer/DeviceArray/ShardedDeviceArray, Numpy arrays scalars of supported types (see implementation), or Python scalars. Got type <class 'jax.interpreters.batching.BatchTracer'>
I0603 13:10:37.186771 3754 jax_jit.cc:892] ComputeSignature failed: INVALID_ARGUMENT: Not supported: The C++ ToPyArgSignature only accepts Buffer/DeviceArray/ShardedDeviceArray, Numpy arrays scalars of supported types (see implementation), or Python scalars. Got type <class 'jax.interpreters.ad.JVPTracer'>
```
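To illustrate why the signature computation fails, here is a minimal sketch that records the type of the argument a plain (non-jitted) Python function receives when called directly, under `jax.grad`, and under `jax.vmap`. The helper `record` and the list `seen` are hypothetical, for illustration only. Under a transformation the argument is a tracer rather than a concrete array, which matches the types reported in the log lines above. The exact tracer class seen under `grad` may differ between JAX versions.

```python
import jax
import jax.numpy as jnp

seen = []  # type names of the argument observed at the call boundary (hypothetical helper state)

def record(x):
    # Plain Python function, not jitted: x is whatever the caller passes in,
    # i.e. the same kind of value a @jax.jit wrapper would receive here.
    seen.append(type(x).__name__)
    return x * 2

record(jnp.float32(1.0))           # direct call: a concrete array
jax.grad(record)(1.0)              # under grad: an autodiff tracer
jax.vmap(record)(jnp.arange(3.0))  # under vmap: a BatchTracer

print(seen)
```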
Would it be possible to add C++ paths that handle these tracer types, or to somehow strip the tracer before entering the C++ codepath?
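For comparison, here is a sketch of the case described above as working well, with jit as the outermost transformation. It is not the requested C++ change; it only contrasts the two call patterns. The helper `time_call` and the names `loop`, `slow_grad`, and `fast_grad` are hypothetical, and the measured timings will depend on the JAX version and hardware.

```python
import time
import jax

@jax.jit
def fun(x):
    return x * 2

def loop(y):
    for _ in range(10):
        y = fun(y)
    return y

# Pattern from the report: grad outside, jit inside. Each of the ten calls to
# `fun` receives a tracer instead of a concrete array.
slow_grad = jax.grad(loop)

# jit as the outermost transformation: the whole gradient is traced and
# compiled once; later calls pass a concrete scalar to the jitted callable.
fast_grad = jax.jit(jax.grad(loop))

def time_call(f, x, n=20):
    f(x).block_until_ready()  # warm-up call so compilation is not timed
    start = time.perf_counter()
    for _ in range(n):
        f(x).block_until_ready()  # wait for async dispatch to finish
    return (time.perf_counter() - start) / n

x = 1.0
print(slow_grad(x), fast_grad(x))  # both should be 1024.0, since loop(y) == 2**10 * y
print(f"grad(loop):      {time_call(slow_grad, x) * 1e3:.3f} ms/call")
print(f"jit(grad(loop)): {time_call(fast_grad, x) * 1e3:.3f} ms/call")
```

With `jax.jit(jax.grad(loop))`, the ten calls to `fun` are traced once while compiling the outer function, so later calls do not go through per-call tracer dispatch. The trade-off is that the outer function is compiled as a whole, which brings back the long compile times that jitting only pieces was meant to avoid.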