Fix integer dict key restoration in metadata serialization (#2587) #2619
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Summary
This PR fixes an issue where integer dict keys and array index keys in a
PyTree path were being converted to strings during metadata serialization
and could not be reconstructed accurately upon restore. This caused a
mismatch between saved and restored metadata, leading to failures when
loading checkpoints for structures containing numeric keys.
Related issue: #2587
Root Cause
NestedKeyMetadataEntry.from_jsonalways interpretednested_key_nameas a string. As a result, numeric key values like
"0"or"1"were notconverted back into integers, even though they originated from numeric
PyTree paths.
Solution
nested_key_nameback tointwhen the JSON value is a digit-only string.KeyMetadataEntry.buildto correctly preserve integer key types.Testing
A new test was added:
orbax/checkpoint/tests/test_export_type_conversions.pyThis integration test: