
nnx.tabulate expects concrete values in unreleased flax>0.12.0 #5067

@DBraun

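I installed the latest flax from source with `pip install -U git+https://github.com/google/flax.git` (specifically commit 74985b2). With this build, nnx.tabulate fails on a module that passes a shape to an nnx.Rngs method such as `uniform`. The integers in that shape reach `jax.random.uniform` as tracers instead of concrete values, so JAX rejects the shape. The same code works with flax 0.12.0 from PyPI.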
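
For context on what "concrete" means here: under `jax.jit`, an array's `.shape` is static metadata, so `x.shape[0]` stays a plain Python `int` even while a function is being traced. Integers passed to another jitted function as ordinary arguments are traced, however. This appears to be what happens when the new tabulate routes the `Rngs.uniform` call through a jitted wrapper. The minimal sketch below uses plain JAX, not Flax internals, and records the type seen in each case. The `outer`/`inner` names and the `seen` dict are illustrative only.

```python
import jax
import jax.numpy as jnp

seen = {}


def outer(x):
    # Array shapes are static metadata: even while tracing, x.shape[0] is a Python int.
    seen["from_shape"] = type(x.shape[0])

    def inner(shape):
        # Arguments to a jitted function are traced, including plain Python ints.
        seen["from_arg"] = type(shape[0])
        return jnp.zeros(())

    # Crossing a second jit boundary turns the shape tuple's ints into tracers.
    jax.jit(inner)((x.shape[0], 10))
    return x * 2


jax.jit(outer)(jnp.zeros((4, 8)))
print(seen["from_shape"].__name__)  # int
print(seen["from_arg"].__name__)
```

The exact tracer class name printed on the last line depends on the JAX version. The traceback below reports `JitTracer` for JAX 0.8.0.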

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): WSL2 on Windows Pro
  • Flax, jax, jaxlib versions (obtained with `pip show flax jax jaxlib`):
```
Name: flax
Version: 0.12.0
Summary: Flax: A neural network library for JAX designed for flexibility
Home-page: https://github.com/google/flax
Author:
Author-email: Flax team <flax-dev@google.com>
License:
Location: /home/admin/.local/lib/python3.12/site-packages
Requires: jax, msgpack, numpy, optax, orbax-checkpoint, PyYAML, rich, tensorstore, treescope, typing_extensions
Required-by: clu, evosax, jax-ai-stack, jraphx
---
Name: jax
Version: 0.8.0
Summary: Differentiate, compile, and transform Numpy code.
Home-page: https://github.com/jax-ml/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /home/admin/.local/lib/python3.12/site-packages
Requires: jaxlib, ml_dtypes, numpy, opt_einsum, scipy
Required-by: chex, clu, evosax, flax, jax-ai-stack, jaxloudnorm, jraphx, optax, orbax-checkpoint, orbax-export
---
Name: jaxlib
Version: 0.8.0
Summary: XLA library for JAX
Home-page: https://github.com/jax-ml/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /home/admin/.local/lib/python3.12/site-packages
Requires: ml_dtypes, numpy, scipy
Required-by: chex, clu, jax, jraphx, optax, orbax-export
```
  • Python version: 3.12
  • GPU/TPU model and memory: RTX 4080
  • CUDA version (if applicable): 12

Problem you have encountered:

Minimal reproducible example (MRE):

```python
from jax import numpy as jnp
from flax import nnx


class Net(nnx.Module):
    def __init__(self):
        self.rngs = nnx.Rngs(0)

    def __call__(self, x):
        return self.rngs.uniform((x.shape[0], 10))


if __name__ == '__main__':
    net = Net()
    x = jnp.zeros((4, 8))
    print("running forward pass")
    y = net(x)
    print("running tabulate")
    print(nnx.tabulate(net, x, console_kwargs={"width": 200}))
    print("all done")
```

What you expected to happen:

With flax 0.12.0 from PyPI, the output is:

```
running forward pass
running tabulate
                           Net Summary                           
┏━━━━━━━━━━━━━━┳━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━┓
┃ path         ┃ type ┃ inputs       ┃ outputs       ┃ RngState ┃
┡━━━━━━━━━━━━━━╇━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━┩
│              │ Net  │ float32[4,8] │ float32[4,10] │ 2 (12 B) │
├──────────────┼──────┼──────────────┼───────────────┼──────────┤
│ rngs/uniform │ Rngs │ - 4          │ float32[4,10] │ 2 (12 B) │
│              │      │ - 10         │               │          │
├──────────────┼──────┼──────────────┼───────────────┼──────────┤
│              │      │              │         Total │ 2 (12 B) │
└──────────────┴──────┴──────────────┴───────────────┴──────────┘
                                                                 
                   Total Parameters: 2 (12 B)                    

all done
```

Logs, error messages, etc.:

Output from the MRE above:

```
running forward pass
running tabulate
Traceback (most recent call last):
  File "/mnt/c/Users/admin/AppData/Roaming/JetBrains/PyCharm2025.2/scratches/scratch_33.py", line 19, in <module>
    print(nnx.tabulate(net, x, console_kwargs={"width": 200}))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/admin/.local/lib/python3.12/site-packages/flax/nnx/summary.py", line 385, in tabulate
    jits[(type(obj), method)].trace(obj, *input_args, **input_kwargs)
  File "/home/admin/.local/lib/python3.12/site-packages/flax/nnx/transforms/compilation.py", line 475, in trace
    traced = self.jitted_fn.trace(*pure_args, **pure_kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/admin/.local/lib/python3.12/site-packages/flax/nnx/transforms/compilation.py", line 129, in __call__
    out = self.f(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/admin/.local/lib/python3.12/site-packages/flax/nnx/summary.py", line 215, in wrapper
    return f(obj, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/admin/AppData/Roaming/JetBrains/PyCharm2025.2/scratches/scratch_33.py", line 10, in __call__
    return self.rngs.uniform((x.shape[0], 10))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/admin/.local/lib/python3.12/site-packages/flax/nnx/transforms/compilation.py", line 447, in __call__
    pure_args_out, pure_kwargs_out, pure_out = self.jitted_fn(
                                               ^^^^^^^^^^^^^^^
  File "/home/admin/.local/lib/python3.12/site-packages/flax/nnx/transforms/compilation.py", line 129, in __call__
    out = self.f(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/admin/.local/lib/python3.12/site-packages/flax/nnx/summary.py", line 215, in wrapper
    return f(obj, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/admin/.local/lib/python3.12/site-packages/flax/nnx/rnglib.py", line 62, in rngs_random_method
    return random_f(self(), *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/admin/.local/lib/python3.12/site-packages/jax/_src/random.py", line 425, in uniform
    shape = core.canonicalize_shape(shape)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: Shapes must be 1D sequences of concrete values of integer type, got (JitTracer<~int32[]>, JitTracer<~int32[]>).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function rngs_random_method at /home/admin/.local/lib/python3.12/site-packages/flax/nnx/rnglib.py:61 for jit. This concrete value was not available in Python because it depends on the value of the argument args[0][0].
The error occurred while tracing the function rngs_random_method at /home/admin/.local/lib/python3.12/site-packages/flax/nnx/rnglib.py:61 for jit. This concrete value was not available in Python because it depends on the value of the argument args[0][1].
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
```
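
In the traceback, the call to `self.rngs.uniform` passes through a jitted wrapper (`compilation.py`, then `wrapper` in `summary.py`) before it reaches `jax.random.uniform`. JAX then reports that `args[0][0]` and `args[0][1]`, the two entries of the shape tuple, were traced. Assuming that is the mechanism, the sketch below reproduces the same `TypeError` in plain JAX. It also shows the `static_argnums` remedy that the error message suggests, which keeps the tuple concrete. The helper `sample` is hypothetical and is not Flax's code.

```python
import jax


def sample(key, shape):
    # jax.random.uniform needs `shape` to be a tuple of concrete Python ints.
    return jax.random.uniform(key, shape)


key = jax.random.key(0)

# Passing the shape as an ordinary jit argument: each int becomes a tracer,
# so canonicalizing the shape raises a TypeError like the one above.
try:
    jax.jit(sample)(key, (4, 10))
    error_message = None
except TypeError as err:
    error_message = str(err)

# Declaring the shape static keeps it a hashable, concrete tuple at trace time.
y = jax.jit(sample, static_argnums=1)(key, (4, 10))
print(error_message.splitlines()[0])
print(y.shape)  # (4, 10)
```

nnx.tabulate gives the caller no equivalent knob for the wrapped `Rngs` method. The example therefore only illustrates why the shape stops being concrete. It is not a workaround for this issue.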
