Description
Summary
When using nnx.List (which stores items with integer dict keys) with Orbax checkpointing, integer keys are converted to strings during save but not converted back during restore. This causes nnx.update() to fail silently because Python dict key lookup is type-strict ("0" != 0).
Minimal Reproduction
"""Minimal script proving the Orbax integer-to-string key conversion bug."""
import os
os.environ['JAX_PLATFORMS'] = 'cpu'
import jax
import jax.numpy as jnp
from flax import nnx
import orbax.checkpoint as ocp
from etils import epath
import tempfile
import warnings
warnings.filterwarnings("ignore")
# Step 1: Create an nnx.List with parameters
my_list = nnx.List([
    nnx.Param(jnp.array([1.0, 2.0])),
    nnx.Param(jnp.array([3.0, 4.0])),
])
state = nnx.state(my_list)
print(f"Original state keys: {list(state.keys())}") # [0, 1] - integers
print(f"First key type: {type(list(state.keys())[0])}") # <class 'int'>
# Step 2: Save and restore checkpoint
with tempfile.TemporaryDirectory() as tmpdir:
    checkpoint_path = epath.Path(tmpdir) / "test_checkpoint"
    ckptr = ocp.Checkpointer(ocp.PyTreeCheckpointHandler(use_ocdbt=False, use_zarr3=False))
    ckptr.save(checkpoint_path, state)
    restored = ckptr.restore(checkpoint_path)
    print(f"Restored keys: {list(restored.keys())}") # ['0', '1'] - strings!
    print(f"First key type: {type(list(restored.keys())[0])}") # <class 'str'>
# Step 3: Try to update with nnx.update
new_list = nnx.List([
    nnx.Param(jnp.array([99.0, 99.0])),
    nnx.Param(jnp.array([99.0, 99.0])),
])
print(f"Before update: {new_list[0].value}") # [99. 99.]
nnx.update(new_list, restored)
print(f"After update: {new_list[0].value}") # [99. 99.] - NOT UPDATED!
print(f"Expected: {restored['0']['value']}") # [1. 2.]
# The bug: keys don't match
print(f"\n'0' == 0: {('0' == 0)}") # False
print(f"'0' in {{0: ...}}: {('0' in {0: 'value'})}") # False
Output:
Original state keys: [0, 1]
First key type: <class 'int'>
Restored keys: ['0', '1']
First key type: <class 'str'>
Before update: [99. 99.]
After update: [99. 99.] # ❌ NOT UPDATED!
Expected: [1. 2.]
'0' == 0: False
'0' in {0: ...}: False
Root Cause
In orbax/checkpoint/_src/metadata/tree.py:
Line 148 - During save, all keys are converted to strings:
NestedKeyMetadataEntry(
    str(tree_utils.get_key_name(k)), _get_key_metadata_type(k) # str() here!
)
Lines 94-97 - During restore, only SequenceKey is converted back to int:
def _keypath_from_key_type(key_name: str, key_type: KeyType) -> Any:
    if key_type == KeyType.SEQUENCE:
        return jax.tree_util.SequenceKey(int(key_name)) # ✓ Converted back to int
    elif key_type == KeyType.DICT:
        return jax.tree_util.DictKey(key_name) # ❌ Stays as string!
The issue is that nnx.List uses DictKey with integer keys (because it stores items as object attributes), not SequenceKey. So the integer keys get converted to strings and never converted back.
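The asymmetry described above can be reproduced without Orbax at all. The following is a toy model only: `KeyType`, `save_key`, and `restore_key` are hypothetical stand-ins written for this illustration, not the library's actual functions. It assumes nothing beyond what the two snippets above show: save stringifies every key, and restore casts back to int only for sequence keys.

```python
from enum import Enum
from typing import Any, Tuple


class KeyType(Enum):
    """Toy stand-in for the key-type tag stored in checkpoint metadata."""
    SEQUENCE = 1
    DICT = 2


def save_key(key: Any, key_type: KeyType) -> Tuple[str, KeyType]:
    # Models the save path: every key name goes through str().
    return str(key), key_type


def restore_key(key_name: str, key_type: KeyType) -> Any:
    # Models the restore path: only sequence keys are cast back to int.
    if key_type == KeyType.SEQUENCE:
        return int(key_name)
    return key_name


# The same integer key, round-tripped under each key type.
list_key = restore_key(*save_key(0, KeyType.SEQUENCE))
dict_key = restore_key(*save_key(0, KeyType.DICT))

print(repr(list_key))  # 0
print(repr(dict_key))  # '0'
```

An integer key tagged as a dict key comes back as a string, which matches what the reproduction script observes for nnx.List.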
Expected Behavior
When restoring a checkpoint, if a DictKey's key name is a numeric string (e.g., "0", "1"), it should be converted back to an integer so it matches the original key type.
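Until the restore path does this itself, a user-side workaround along the same lines might look like the sketch below. The helper name `restore_int_keys` is hypothetical, and the sketch assumes the restored tree is a plain nested dict, as it is in the reproduction. It has a known limitation: a key that was genuinely the string "0" before saving would also be converted, so it is only safe when no such keys exist.

```python
from typing import Any


def restore_int_keys(tree: Any) -> Any:
    """Return a copy of a nested dict with digit-only string keys cast to int."""
    if not isinstance(tree, dict):
        return tree  # leaf (array, scalar, ...): returned unchanged
    fixed = {}
    for key, value in tree.items():
        # Only ASCII digit strings such as "0" or "12" are converted.
        if isinstance(key, str) and key.isascii() and key.isdigit():
            key = int(key)
        fixed[key] = restore_int_keys(value)
    return fixed


# Same shape as the restored checkpoint in the reproduction.
restored = {"0": {"value": [1.0, 2.0]}, "1": {"value": [3.0, 4.0]}}
fixed = restore_int_keys(restored)
print(list(fixed.keys()))  # [0, 1]
```

The intent would be to pass the converted tree to nnx.update() in place of the raw restored one; that last step is not verified here.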
Impact
This bug affects any code using:
- nnx.List with Orbax checkpointing
- Any NNX Module that uses integer dict keys
- The entire MaxText codebase, which uses nnx.update() for checkpoint loading
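Because the failure is silent, affected code may want a guard that compares key paths before updating. The sketch below is one possible check over plain nested dicts. `key_paths` and `unmatched_paths` are hypothetical helpers, not part of Flax or Orbax, and the `expected` tree is assumed to be a dict view of the target state.

```python
from typing import Any, Iterator, List, Tuple


def key_paths(tree: Any, prefix: Tuple[Any, ...] = ()) -> Iterator[Tuple[Any, ...]]:
    """Yield the key path of every leaf in a nested dict."""
    if isinstance(tree, dict):
        for key, value in tree.items():
            yield from key_paths(value, prefix + (key,))
    else:
        yield prefix


def unmatched_paths(expected: Any, restored: Any) -> List[Tuple[Any, ...]]:
    """Return restored key paths that have no exact match in expected."""
    expected_paths = set(key_paths(expected))
    return [path for path in key_paths(restored) if path not in expected_paths]


expected = {0: {"value": None}, 1: {"value": None}}
restored = {"0": {"value": [1.0, 2.0]}, "1": {"value": [3.0, 4.0]}}

bad = unmatched_paths(expected, restored)
print(bad)  # [('0', 'value'), ('1', 'value')]
```

Raising an error when the returned list is non-empty would turn the silent no-op into a visible failure.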
Environment
- Orbax version: 0.11.26
- Flax version: 0.12.0
- JAX version: 0.8.0