
TraceAnnotation not showing inside jax.jit #12381

Open
@ilemhadri

Description

With the following snippet:

```python
import jax
import jax.numpy as jnp

@jax.jit
def compute(M):
    Z = M @ M
    with jax.profiler.TraceAnnotation("InsideAnnotation"):
        for _ in range(10):
            Z = jnp.sqrt(Z @ Z)
            _, _ = jnp.linalg.eigh(Z)
    # Return Z rather than a constant so XLA cannot dead-code-eliminate the work.
    return Z

dim = 1000
# shape=(dim,) draws dim samples of a dim-dimensional normal, giving a
# (dim, dim) matrix; shape=(dim, dim) would give a (dim, dim, dim) array.
u = jax.random.multivariate_normal(
    key=jax.random.PRNGKey(120),
    mean=jnp.zeros(dim),
    cov=jnp.identity(dim),
    shape=(dim,),
)
with jax.profiler.trace("/home/ismlemhadri/research/ada/jax-trace", create_perfetto_link=True):
    result = jax.block_until_ready(compute(u @ u))
```

I get the following trace in Perfetto: https://i.imgur.com/iTXI8Ic.png
The trace shows only a JaxCompiledFunction event, whereas I would expect an InsideAnnotation event as well.
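One likely explanation, which this issue does not confirm: under jax.jit, the Python body of compute runs only while JAX traces it. Later calls execute the cached compiled XLA program without re-entering Python. TraceAnnotation is a Python-side context manager, so at most it would mark the brief tracing step on the host, not the compiled computation. The minimal sketch below illustrates the trace-once behaviour. It uses a hypothetical counting_annotation context manager as a stand-in for TraceAnnotation, and does not reflect how the profiler itself is implemented.

```python
import contextlib

import jax
import jax.numpy as jnp

# Hypothetical stand-in for jax.profiler.TraceAnnotation: it only counts
# how many times Python enters the context manager.
enter_count = 0

@contextlib.contextmanager
def counting_annotation(name):
    global enter_count
    enter_count += 1
    yield

@jax.jit
def compute(m):
    z = m @ m
    with counting_annotation("InsideAnnotation"):
        z = jnp.sqrt(z @ z)
    return z

x = jnp.ones((4, 4))
for _ in range(3):
    compute(x).block_until_ready()

# The Python body (and so the context manager) ran once, during tracing;
# the later calls reused the cached compiled executable.
print(enter_count)  # 1
```

Calls with a new input shape or dtype would trigger a retrace and increment the counter again. Calls with the same shape and dtype do not.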

This is the smallest reproducible example I could think of. In my use case, I have many more annotations and observe similar issues.
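A possible workaround, offered as an untested sketch rather than a confirmed fix. It assumes a JAX version that provides jax.named_scope, which attaches a name to the metadata of every operation traced inside the block, so the name is carried into the compiled program. Whether those names appear in the Perfetto view depends on the backend and JAX version. This has not been checked for jax 0.3.17 on CPU. For a host-side region, a TraceAnnotation can wrap the call to the jitted function instead of code inside it. The names inside_scope and compute_call and the input x are illustrative.

```python
import jax
import jax.numpy as jnp

@jax.jit
def compute(m):
    z = m @ m
    # jax.named_scope records "inside_scope" in the op metadata of every
    # operation traced in this block, so the name is carried into the
    # compiled XLA program instead of existing only during tracing.
    with jax.named_scope("inside_scope"):
        z = jnp.sqrt(z @ z)
    return z

x = jnp.ones((4, 4))

# The scope name should appear in the op_name metadata of the compiled HLO.
hlo_text = compute.lower(x).compile().as_text()
print("inside_scope" in hlo_text)

# A TraceAnnotation outside jit brackets the host-side call; because
# block_until_ready is inside the region, its duration includes execution.
with jax.profiler.TraceAnnotation("compute_call"):
    out = jax.block_until_ready(compute(x))
```

The outer annotation marks only the call as a whole. The named scope is what distinguishes operations inside the compiled function.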

What jax/jaxlib version are you using?

jax 0.3.17, jaxlib 0.3.15

Which accelerator(s) are you using?

CPU

Additional System Info

Python 3.8.13

Labels: bug
