
PyTreeDef comparison fails type checking for jax==0.6.1 #29037

Open
@mfschubert

Description


Starting with jax 0.6.1, mypy no longer seems to recognize that PyTreeDef instances support == comparison. Specifically, the following file:

from typing import Any
from jax import tree_util

def test_pytreedef_equal(a: Any, b: Any) -> bool:
  return tree_util.tree_structure(a) == tree_util.tree_structure(b)

fails mypy type checking with the following error:

test.py:6: error: Unsupported left operand type for == (PyTreeDef?)  [operator]
Found 1 error in 1 file (checked 1 source file)

This is demonstrated in the following Colab.

The error does not occur with jax 0.6.0 or earlier.
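A possible stopgap until this is fixed, sketched under the assumption that only the static type information regressed (the report shows a mypy error, not a runtime failure): casting the left operand to Any with typing.cast stops mypy from checking == against the PyTreeDef stub, while at runtime the comparison still uses PyTreeDef's own equality. The helper name pytreedefs_equal is hypothetical.

```python
from typing import Any, cast

from jax import tree_util


def pytreedefs_equal(a: Any, b: Any) -> bool:
  # Hypothetical helper. cast(Any, ...) is a no-op at runtime but makes mypy
  # treat the left operand as Any, so the == below is not checked against
  # the PyTreeDef stub that triggers the [operator] error.
  lhs = cast(Any, tree_util.tree_structure(a))
  rhs = tree_util.tree_structure(b)
  # At runtime this still dispatches to PyTreeDef.__eq__, which compares
  # tree structure only (node types and keys), not leaf values.
  return bool(lhs == rhs)
```

For example, two dicts with the same keys but different leaf values compare equal, while a list and a tuple holding the same leaves do not. A targeted `# type: ignore[operator]` on the original return line is an alternative that avoids the cast, at the cost of silencing any other operator error reported on that line.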

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.6.1
jaxlib: 0.6.1
numpy: 2.0.2
python: 3.11.12 (main, Apr 9 2025, 08:55:54) [GCC 11.4.0]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='44b407801e70', release='6.1.123+', version='#1 SMP PREEMPT_DYNAMIC Sun Mar 30 16:01:29 UTC 2025', machine='x86_64')
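The fields above appear to match the report printed by jax.print_environment_info(). A minimal sketch for regenerating it when re-testing on another version (for example 0.6.0 versus 0.6.1), assuming a JAX release where the function accepts the return_string argument:

```python
import jax

# With return_string=True the environment report is returned as a string
# instead of being printed, so it can be saved or pasted into an issue.
report = jax.print_environment_info(return_string=True)
print(report)
```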


Labels: bug
