Description
With this snippet of code:
import jax
import jax.numpy as jnp

@jax.jit
def compute(M):
    Z = M @ M
    with jax.profiler.TraceAnnotation("InsideAnnotation"):
        for _ in range(10):
            Z = jnp.sqrt(Z @ Z)
        _, _ = jnp.linalg.eigh(Z)
    return 0

dim = 1000
u = jax.random.multivariate_normal(key=jax.random.PRNGKey(120), mean=jnp.zeros(dim), cov=jnp.identity(dim), shape=(dim, dim))

with jax.profiler.trace("/home/ismlemhadri/research/ada/jax-trace", create_perfetto_link=True):
    result = jax.block_until_ready(compute(u @ u))
I get the following trace in Perfetto: https://i.imgur.com/iTXI8Ic.png. The trace shows only a JaxCompiledFunction event, whereas I would expect to see InsideAnnotation as well.
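A possible explanation, offered as an assumption rather than a confirmed diagnosis: jax.profiler.TraceAnnotation is a host-side (Python) context manager, and Python code inside a @jax.jit function runs only while JAX traces it. The annotation would therefore open and close during tracing, not while the compiled computation runs on the device. The sketch below moves the annotation outside the jitted region and counts how often the jitted body's Python code actually executes. The helper step and the parameter n_steps are hypothetical names introduced for illustration.

```python
import jax
import jax.numpy as jnp

# Counts how many times the Python body of `step` actually executes.
python_calls = {"count": 0}

@jax.jit
def step(z):
    # Python code here (a counter, or a TraceAnnotation context) runs only
    # while JAX traces the function; cached calls skip it entirely.
    python_calls["count"] += 1
    return jnp.sqrt(z @ z)

def compute(m, n_steps=10):
    # Hypothetical restructuring: the loop and the annotation live on the host,
    # so the annotation brackets real device work on every call.
    z = m @ m
    with jax.profiler.TraceAnnotation("InsideAnnotation"):
        for _ in range(n_steps):
            z = step(z)
        z = jax.block_until_ready(z)
    return z

m = jnp.eye(4) * 2.0
out = compute(m)
# Traced once, then the cached executable is reused for the remaining calls.
print(python_calls["count"])
```

If this assumption holds, capturing compute(m) under jax.profiler.trace as in the original snippet should show InsideAnnotation as a host-side span around the ten step calls. Separately, because the original compute returns 0, its outputs do not depend on Z, so XLA may be free to eliminate most of that computation.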
This is the smallest reproducible example I could think of. In my use case, I have many more annotations and observe similar issues.
What jax/jaxlib version are you using?
jax 0.3.17, jaxlib 0.3.15
Which accelerator(s) are you using?
CPU
Additional System Info
Python 3.8.13