Orbax converts integer dict keys to strings and doesn't convert them back, breaking nnx.update() #2561

@eitanporat

Summary

When nnx.List (which stores its items under integer dict keys) is checkpointed with Orbax, the integer keys are converted to strings during save but are not converted back during restore. As a result, nnx.update() fails silently: no error is raised and the parameters keep their old values, because Python dict key lookup is type-strict ("0" != 0) and the restored string keys never match the module's integer keys.

Minimal Reproduction

"""Minimal script proving the Orbax integer-to-string key conversion bug."""

import os
os.environ['JAX_PLATFORMS'] = 'cpu'

import jax
import jax.numpy as jnp
from flax import nnx
import orbax.checkpoint as ocp
from etils import epath
import tempfile
import warnings
warnings.filterwarnings("ignore")

# Step 1: Create an nnx.List with parameters
my_list = nnx.List([
    nnx.Param(jnp.array([1.0, 2.0])),
    nnx.Param(jnp.array([3.0, 4.0])),
])

state = nnx.state(my_list)
print(f"Original state keys: {list(state.keys())}")  # [0, 1] - integers
print(f"First key type: {type(list(state.keys())[0])}")  # <class 'int'>

# Step 2: Save and restore checkpoint
with tempfile.TemporaryDirectory() as tmpdir:
    checkpoint_path = epath.Path(tmpdir) / "test_checkpoint"
    ckptr = ocp.Checkpointer(ocp.PyTreeCheckpointHandler(use_ocdbt=False, use_zarr3=False))
    ckptr.save(checkpoint_path, state)

    restored = ckptr.restore(checkpoint_path)
    print(f"Restored keys: {list(restored.keys())}")  # ['0', '1'] - strings!
    print(f"First key type: {type(list(restored.keys())[0])}")  # <class 'str'>

    # Step 3: Try to update with nnx.update
    new_list = nnx.List([
        nnx.Param(jnp.array([99.0, 99.0])),
        nnx.Param(jnp.array([99.0, 99.0])),
    ])

    print(f"Before update: {new_list[0].value}")  # [99. 99.]
    nnx.update(new_list, restored)
    print(f"After update: {new_list[0].value}")  # [99. 99.] - NOT UPDATED!
    print(f"Expected: {restored['0']['value']}")  # [1. 2.]

    # The bug: keys don't match
    print(f"\n'0' == 0: {('0' == 0)}")  # False
    print(f"'0' in {{0: ...}}: {('0' in {0: 'value'})}")  # False

Output:

Original state keys: [0, 1]
First key type: <class 'int'>
Restored keys: ['0', '1']
First key type: <class 'str'>
Before update: [99. 99.]
After update: [99. 99.]  # ❌ NOT UPDATED!
Expected: [1. 2.]

'0' == 0: False
'0' in {0: ...}: False
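
Until this is fixed in Orbax, one possible user-side workaround is to convert integer-like string keys back to integers before passing the restored tree to nnx.update(). The following is a minimal sketch, not part of Orbax or Flax: `intify_keys` and `_maybe_int` are hypothetical helper names, and the sketch assumes the restored checkpoint is a plain nested dict, as in the output above.

```python
from typing import Any


def _maybe_int(key: Any) -> Any:
    """Return int(key) if key is exactly the str() of an integer, else key."""
    if isinstance(key, str):
        try:
            as_int = int(key)
        except ValueError:
            return key
        # Only convert names that an integer would have produced via str(),
        # so "007" or " 1" are left untouched.
        if str(as_int) == key:
            return as_int
    return key


def intify_keys(tree: Any) -> Any:
    """Recursively rebuild nested dicts, converting integer-like string keys."""
    if isinstance(tree, dict):
        return {_maybe_int(k): intify_keys(v) for k, v in tree.items()}
    return tree


# Stand-in for the dict returned by ckptr.restore() in the reproduction.
restored = {"0": {"value": [1.0, 2.0]}, "1": {"value": [3.0, 4.0]}}
fixed = intify_keys(restored)
print(list(fixed.keys()))
```

In the reproduction script this would be used as `nnx.update(new_list, intify_keys(restored))`. That call has not been verified here against a specific Flax version, and the helper would also convert dict keys that were genuinely the strings "0", "1", and so on before saving.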

Root Cause

In orbax/checkpoint/_src/metadata/tree.py:

Line 148 - During save, all keys are converted to strings:

NestedKeyMetadataEntry(
    str(tree_utils.get_key_name(k)), _get_key_metadata_type(k)  # str() here!
)

Lines 94-97 - During restore, only SequenceKey is converted back to int:

def _keypath_from_key_type(key_name: str, key_type: KeyType) -> Any:
  if key_type == KeyType.SEQUENCE:
    return jax.tree_util.SequenceKey(int(key_name))  # ✓ Converted back to int
  elif key_type == KeyType.DICT:
    return jax.tree_util.DictKey(key_name)  # ❌ Stays as string!

The issue is that nnx.List uses DictKey with integer keys (because it stores items as object attributes), not SequenceKey. So the integer keys get converted to strings and never converted back.
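
The asymmetry can be illustrated without Orbax at all. The sketch below is a simplified stand-in for the two code paths quoted above, not the actual Orbax implementation: `KeyType`, `save_key`, and `restore_key` are illustrative names, and plain Python values are used in place of jax.tree_util key objects.

```python
import enum


class KeyType(enum.Enum):
    SEQUENCE = 1
    DICT = 2


def save_key(key, key_type):
    # Models the save path: the key name always goes through str().
    return str(key), key_type


def restore_key(key_name, key_type):
    # Models the restore path: only SEQUENCE keys get int() applied.
    if key_type is KeyType.SEQUENCE:
        return int(key_name)
    return key_name  # DICT keys stay as strings


seq_roundtrip = restore_key(*save_key(0, KeyType.SEQUENCE))
dict_roundtrip = restore_key(*save_key(0, KeyType.DICT))
print(repr(seq_roundtrip), repr(dict_roundtrip))
```

The sequence key survives the round trip as the integer 0, while the dict key comes back as the string '0'. That is the mismatch nnx.update() then runs into.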

Expected Behavior

When restoring a checkpoint, if a DictKey's key name is a numeric string (e.g., "0", "1"), it should be converted back to an integer so it matches the original key type.
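
A minimal sketch of the restore rule proposed here, using a hypothetical function name (`restore_dict_key`) rather than any real Orbax API. It also shows one caveat that a fix would probably need to address: a dict that originally used the string keys "0" and "1" is restored with integer keys as well, because the key name alone does not record the original type.

```python
def restore_dict_key(key_name: str):
    """Hypothetical restore rule: integer-like names become ints."""
    try:
        as_int = int(key_name)
    except ValueError:
        return key_name
    # Convert only names that str(int) would have produced.
    return as_int if str(as_int) == key_name else key_name


def roundtrip(d: dict) -> dict:
    # Save turns every key into str(); restore applies the rule above.
    return {restore_dict_key(str(k)): v for k, v in d.items()}


original_int_keys = {0: "a", 1: "b"}
original_str_keys = {"0": "a", "1": "b"}

print(roundtrip(original_int_keys) == original_int_keys)
print(roundtrip(original_str_keys) == original_str_keys)
```

Under this rule the integer-keyed dict round-trips correctly, but the string-keyed one does not. Recording the original key type in the checkpoint metadata at save time would be one way to avoid the ambiguity; that is a suggestion, not something Orbax is known to do.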

Impact

This bug affects any code using:

  • nnx.List with Orbax checkpointing
  • Any NNX Module that uses integer dict keys
  • The entire MaxText codebase, which uses nnx.update() for checkpoint loading

Environment

  • Orbax version: 0.11.26
  • Flax version: 0.12.0
  • JAX version: 0.8.0
