
jax_jit.cc support for Tracer: don't cache_miss #10976

@lukemetz


Working with complex jit-ed functions can lead to long compile times, which hurts interactive workflows. One option is to jit only pieces of the function. This works well as long as jit is the top-level transformation applied to those pieces. When it is not, for example when the jitted pieces are called inside jax.grad or jax.vmap, execution slows down dramatically due to Python overhead. For example:

import jax

@jax.jit
def fun(x):
  return x * 2

def loop(y):
  # Under jax.grad, y is a tracer, so each call to the jitted `fun` receives a tracer
  for i in range(10):
    y = fun(y)
  return y

g = jax.grad(loop)(1.)

Upon investigation, JAX's C++ jit codepath misses its cache (cache_miss in the profiler) and falls back to the Python jit codepath. The VLOG output below shows this for BatchTracers and JVPTracers. This is slow: 33 ms per call in my case, whereas the actual computation takes 1.1 ms.

I0603 13:10:37.163035 3754 jax_jit.cc:892] ComputeSignature failed: INVALID_ARGUMENT: Not supported: The C++ ToPyArgSignature only accepts Buffer/DeviceArray/ShardedDeviceArray, Numpy arrays scalars of supported types (see implementation), or Python scalars. Got type <class 'jax.interpreters.batching.BatchTracer'>

I0603 13:10:37.186771 3754 jax_jit.cc:892] ComputeSignature failed: INVALID_ARGUMENT: Not supported: The C++ ToPyArgSignature only accepts Buffer/DeviceArray/ShardedDeviceArray, Numpy arrays scalars of supported types (see implementation), or Python scalars. Got type <class 'jax.interpreters.ad.JVPTracer'>
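To reproduce the overhead, one can time the same jitted function called on a concrete array versus called under jax.grad. The sketch below is a hypothetical benchmark: `time_per_call` and `loop` are illustrative names, not part of JAX, and the absolute numbers will depend on the JAX version and hardware.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def fun(x):
  return x * 2


def loop(y):
  # Ten calls to the jitted `fun`; under jax.grad each call receives a JVPTracer
  for _ in range(10):
    y = fun(y)
  return y


grad_loop = jax.grad(loop)


def time_per_call(f, *args, n=20):
  """Hypothetical helper: average wall-clock seconds per call after one warm-up call."""
  jax.block_until_ready(f(*args))  # warm-up call triggers compilation
  start = time.perf_counter()
  for _ in range(n):
    out = f(*args)
  jax.block_until_ready(out)  # wait for asynchronous dispatch to finish
  return (time.perf_counter() - start) / n


x = jnp.float32(1.0)
print("fun on a concrete array:", time_per_call(fun, x))
print("grad(loop):", time_per_call(grad_loop, x))
```

The gradient itself is 2**10 = 1024; only the per-call time is of interest here.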

Would it be possible to add C++ paths for tracers, or somehow strip the tracer before entering the C++ codepath?
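Until such a path exists, one possible workaround (an assumption here, not something proposed in the issue) is to apply jax.jit outermost, around the transformed function. The inner jitted calls are then traced once, and later calls with the same input types reuse the compiled executable. The trade-off is that the whole gradient computation is compiled as one program, which is exactly the compile-time cost the issue is trying to avoid.

```python
import jax


@jax.jit
def fun(x):
  return x * 2


def loop(y):
  for _ in range(10):
    y = fun(y)
  return y


# Hypothetical workaround: jit the transformed function so the tracer-handling
# Python path runs only while tracing, not on every call.
fast_grad = jax.jit(jax.grad(loop))
print(fast_grad(1.0))
```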

Labels: enhancement (New feature or request)
