Description
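I installed the latest Flax from GitHub with `pip install -U git+https://github.com/google/flax.git` (specifically commit 74985b2). With this version, `nnx.tabulate` fails when a module passes a shape derived from its input (here `x.shape[0]`) to `rngs.uniform`. It seems to require shapes to be concrete, but the shape reaches `jax.random.uniform` as traced values instead of concrete integers.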
System information
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): WSL2 on Windows Pro
- Flax, jax, jaxlib versions (obtained with `pip show flax jax jaxlib`):

```
Name: flax
Version: 0.12.0
Summary: Flax: A neural network library for JAX designed for flexibility
Home-page: https://github.com/google/flax
Author:
Author-email: Flax team <flax-dev@google.com>
License:
Location: /home/admin/.local/lib/python3.12/site-packages
Requires: jax, msgpack, numpy, optax, orbax-checkpoint, PyYAML, rich, tensorstore, treescope, typing_extensions
Required-by: clu, evosax, jax-ai-stack, jraphx
---
Name: jax
Version: 0.8.0
Summary: Differentiate, compile, and transform Numpy code.
Home-page: https://github.com/jax-ml/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /home/admin/.local/lib/python3.12/site-packages
Requires: jaxlib, ml_dtypes, numpy, opt_einsum, scipy
Required-by: chex, clu, evosax, flax, jax-ai-stack, jaxloudnorm, jraphx, optax, orbax-checkpoint, orbax-export
---
Name: jaxlib
Version: 0.8.0
Summary: XLA library for JAX
Home-page: https://github.com/jax-ml/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /home/admin/.local/lib/python3.12/site-packages
Requires: ml_dtypes, numpy, scipy
Required-by: chex, clu, jax, jraphx, optax, orbax-export
```
- Python version: 3.12
- GPU/TPU model and memory: RTX 4080
- CUDA version (if applicable): 12
Problem you have encountered:
MRE:

```python
from jax import numpy as jnp
from flax import nnx


class Net(nnx.Module):
    def __init__(self):
        self.rngs = nnx.Rngs(0)

    def __call__(self, x):
        return self.rngs.uniform((x.shape[0], 10))


if __name__ == '__main__':
    net = Net()
    x = jnp.zeros((4, 8))
    print("running forward pass")
    y = net(x)
    print("running tabulate")
    print(nnx.tabulate(net, x, console_kwargs={"width": 200}))
    print("all done")
```

What you expected to happen:
With flax 0.12.0 from PyPI, the output is:

```
running forward pass
running tabulate
                           Net Summary
┏━━━━━━━━━━━━━━┳━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━┓
┃ path         ┃ type ┃ inputs       ┃ outputs       ┃ RngState ┃
┡━━━━━━━━━━━━━━╇━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━┩
│              │ Net  │ float32[4,8] │ float32[4,10] │ 2 (12 B) │
├──────────────┼──────┼──────────────┼───────────────┼──────────┤
│ rngs/uniform │ Rngs │ - 4          │ float32[4,10] │ 2 (12 B) │
│              │      │ - 10         │               │          │
├──────────────┼──────┼──────────────┼───────────────┼──────────┤
│              │      │              │         Total │ 2 (12 B) │
└──────────────┴──────┴──────────────┴───────────────┴──────────┘
Total Parameters: 2 (12 B)
all done
```
Logs, error messages, etc:
Output from the MRE above:

```
running forward pass
running tabulate
Traceback (most recent call last):
  File "/mnt/c/Users/admin/AppData/Roaming/JetBrains/PyCharm2025.2/scratches/scratch_33.py", line 19, in <module>
    print(nnx.tabulate(net, x, console_kwargs={"width": 200}))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/admin/.local/lib/python3.12/site-packages/flax/nnx/summary.py", line 385, in tabulate
    jits[(type(obj), method)].trace(obj, *input_args, **input_kwargs)
  File "/home/admin/.local/lib/python3.12/site-packages/flax/nnx/transforms/compilation.py", line 475, in trace
    traced = self.jitted_fn.trace(*pure_args, **pure_kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/admin/.local/lib/python3.12/site-packages/flax/nnx/transforms/compilation.py", line 129, in __call__
    out = self.f(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/admin/.local/lib/python3.12/site-packages/flax/nnx/summary.py", line 215, in wrapper
    return f(obj, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/admin/AppData/Roaming/JetBrains/PyCharm2025.2/scratches/scratch_33.py", line 10, in __call__
    return self.rngs.uniform((x.shape[0], 10))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/admin/.local/lib/python3.12/site-packages/flax/nnx/transforms/compilation.py", line 447, in __call__
    pure_args_out, pure_kwargs_out, pure_out = self.jitted_fn(
                                               ^^^^^^^^^^^^^^^
  File "/home/admin/.local/lib/python3.12/site-packages/flax/nnx/transforms/compilation.py", line 129, in __call__
    out = self.f(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/admin/.local/lib/python3.12/site-packages/flax/nnx/summary.py", line 215, in wrapper
    return f(obj, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/admin/.local/lib/python3.12/site-packages/flax/nnx/rnglib.py", line 62, in rngs_random_method
    return random_f(self(), *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/admin/.local/lib/python3.12/site-packages/jax/_src/random.py", line 425, in uniform
    shape = core.canonicalize_shape(shape)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: Shapes must be 1D sequences of concrete values of integer type, got (JitTracer<~int32[]>, JitTracer<~int32[]>).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function rngs_random_method at /home/admin/.local/lib/python3.12/site-packages/flax/nnx/rnglib.py:61 for jit. This concrete value was not available in Python because it depends on the value of the argument args[0][0].
The error occurred while tracing the function rngs_random_method at /home/admin/.local/lib/python3.12/site-packages/flax/nnx/rnglib.py:61 for jit. This concrete value was not available in Python because it depends on the value of the argument args[0][1].
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
```
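For reference, the same `TypeError` can be reproduced with plain JAX, without Flax. This is an illustrative sketch, not Flax's code: it assumes that, at commit 74985b2, `nnx.tabulate` effectively hands the shape tuple to `jax.random.uniform` as a dynamic `jax.jit` argument. The traceback points that way (`rngs_random_method` is traced "for jit" and both shape entries are `JitTracer<~int32[]>`). The helper name `sample` is hypothetical.

```python
import jax


def sample(key, shape):
    # jax.random.uniform needs `shape` to be a tuple of concrete Python ints.
    return jax.random.uniform(key, shape)


key = jax.random.key(0)

# Dynamic argument: jit converts the ints in `shape` into tracers, so
# jax.random.uniform raises a TypeError like the one in the traceback.
dynamic_failed = False
try:
    jax.jit(sample)(key, (4, 10))
except TypeError as err:
    dynamic_failed = True
    print("dynamic shape:", str(err).splitlines()[0])

# Static argument: the shape tuple is treated as a compile-time constant
# and stays concrete while tracing, as the error message suggests.
out = jax.jit(sample, static_argnums=1)(key, (4, 10))
print("static shape:", out.shape)
```

If that reading is right, one possible user-side workaround (untested against 74985b2) is to draw a key with `self.rngs()` and call `jax.random.uniform(key, (x.shape[0], 10))` directly in `Net.__call__`. The shape then never becomes an argument of an `Rngs` method wrapped by `tabulate`, and `x.shape[0]` is still a plain Python int while tracing.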