Fix overloaded type signature for jax.numpy.where. #28314

Open, wants to merge 1 commit into main

carlosgmartin
Contributor

The current type annotation for jax.numpy.where causes type errors that require assertions to silence. Here is a simple example:

test.py:

import jax
from jax import numpy as jnp, random

def categorical(key, logits, axis=-1, shape=None, where=None):
    if where is not None:
        logits = jnp.where(where, logits, -jnp.inf)
        # assert isinstance(logits, jax.Array)
    return random.categorical(key, logits, axis, shape)

def argmax(x, *args, where=None, **kwargs):
    # assert isinstance(x, jax.Array)
    if where is None:
        return x.argmax(*args, **kwargs)
    else:
        return jnp.where(where, x, -jnp.inf).argmax(*args, **kwargs)

$ pyright test.py
/Users/carlos/Downloads/test.py
  /Users/carlos/Downloads/test.py:8:36 - error: Argument of type "Array | tuple[Array, ...] | Unknown" cannot be assigned to parameter "logits" of type "RealArray" in function "categorical"
    Type "Array | tuple[Array, ...] | Unknown" is not assignable to type "RealArray"
      Type "tuple[Array, ...]" is not assignable to type "RealArray"
        "tuple[Array, ...]" is not assignable to "Array"
        "tuple[Array, ...]" is not assignable to "ndarray[Unknown, Unknown]"
        "tuple[Array, ...]" is not assignable to "bool_"
        "tuple[Array, ...]" is not assignable to "number[Unknown]"
        "tuple[Array, ...]" is not assignable to "bool"
        "tuple[Array, ...]" is not assignable to "int"
    ... (reportArgumentType)
  /Users/carlos/Downloads/test.py:15:46 - error: Cannot access attribute "argmax" for class "tuple[Array, ...]"
    Attribute "argmax" is unknown (reportAttributeAccessIssue)
2 errors, 0 warnings, 0 informations 

Specifically, this is caused by the third overload here and here, which is superfluous (see the example here).

This commit removes that overload.

Coincidentally, removing the overload unmasked a couple of minor type issues, which are also fixed in this commit.
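
For readers who want to see the shape of the problem without opening the JAX source, here is a minimal sketch. It assumes a hypothetical stand-in, `where_sketch`, that works on plain Python sequences; its three signatures only mirror the structure described above (a one-argument form returning a tuple, a three-argument form returning a single value, and a catch-all returning the union) and are not copied from JAX.

```python
from typing import List, Optional, Sequence, Tuple, Union, overload


# Hypothetical stand-in for jax.numpy.where; the names and types are
# illustrative assumptions, not JAX's actual annotations.
@overload
def where_sketch(
    condition: Sequence[bool], x: None = None, y: None = None
) -> Tuple[List[int], ...]: ...
@overload
def where_sketch(
    condition: Sequence[bool], x: Sequence[float], y: Sequence[float]
) -> List[float]: ...
# Catch-all third signature: any call resolved against it gets the union
# return type, which callers then have to narrow by hand.
@overload
def where_sketch(
    condition: Sequence[bool],
    x: Optional[Sequence[float]] = None,
    y: Optional[Sequence[float]] = None,
) -> Union[List[float], Tuple[List[int], ...]]: ...


def where_sketch(condition, x=None, y=None):
    if x is None and y is None:
        # One-argument form: indices of the true entries, wrapped in a tuple.
        return ([i for i, c in enumerate(condition) if c],)
    if x is None or y is None:
        raise ValueError("x and y must be given together")
    # Three-argument form: element-wise selection, a single list.
    return [a if c else b for c, a, b in zip(condition, x, y)]


print(where_sketch([True, False, True]))                    # ([0, 2],)
print(where_sketch([True, False], [1.0, 2.0], [9.0, 9.0]))  # [1.0, 9.0]
```

At runtime the two call forms never mix, so the first two signatures already cover every valid call. The union in the third signature looks like the likely source of the `Array | tuple[Array, ...]` type in the pyright output above.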

jakevdp self-assigned this Apr 28, 2025
Collaborator

jakevdp left a comment

Looks good!

google-ml-butler bot added the kokoro:force-run and pull ready labels Apr 28, 2025
@jakevdp
Collaborator

jakevdp commented Apr 28, 2025

Unfortunately we can't merge this because it leads to a bunch of pytype errors. It looks like pytype requires the removed overload in order to function correctly.

@carlosgmartin
Contributor Author

carlosgmartin commented Apr 28, 2025

@jakevdp pytype doesn't seem to be run by pre-commit, so I don't see that error:

$ pre-commit run --all
check python ast.........................................................Passed
check for merge conflicts................................................Passed
check toml...............................................................Passed
check yaml...............................................................Passed
fix end of files.........................................................Passed
debug statements (python)................................................Passed
trim trailing whitespace.................................................Passed
ruff.....................................................................Passed
mypy.....................................................................Passed
jupytext.................................................................Passed
$

Should pytype be added to pre-commit?

@jakevdp
Collaborator

jakevdp commented Apr 28, 2025

We've experimented with adding pytype to the GitHub pre-commit, but it's too slow.

@jakevdp
Collaborator

jakevdp commented Apr 28, 2025

Static type checking is a mess (not just in JAX, but in Python in general). There are multiple implementations, each of which diverges in a different direction at the edges.

If you're running into issues, the best solution is probably to use your tool-specific ignore statement and move on.
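
As a rough illustration of what that can look like, the sketch below uses a hypothetical helper, `select_sketch`, whose declared return type is wider than what it actually returns, and shows three ways to get past the resulting diagnostic. The helper and function names are assumptions made for the example. The rule name `reportArgumentType` is taken from the pyright output earlier in this thread, and pytype's inline form is `# pytype: disable=<error-name>`.

```python
from typing import List, Tuple, Union, cast


def select_sketch(
    mask: List[bool], x: List[float], y: List[float]
) -> Union[List[float], Tuple[List[int], ...]]:
    """Hypothetical helper with a too-wide declared return type."""
    return [a if m else b for m, a, b in zip(mask, x, y)]


def total(values: List[float]) -> float:
    return float(sum(values))


def total_with_ignore(mask: List[bool], x: List[float]) -> float:
    # Option 1: suppress the diagnostic on the offending line.
    # `# type: ignore` is the generic form; pyright also accepts
    # `# pyright: ignore[reportArgumentType]`.
    return total(select_sketch(mask, x, [0.0] * len(x)))  # type: ignore


def total_with_cast(mask: List[bool], x: List[float]) -> float:
    # Option 2: tell the checker which member of the union to assume.
    picked = cast(List[float], select_sketch(mask, x, [0.0] * len(x)))
    return total(picked)


def total_with_assert(mask: List[bool], x: List[float]) -> float:
    # Option 3: narrow at runtime, as in the commented-out assertions above.
    picked = select_sketch(mask, x, [0.0] * len(x))
    assert isinstance(picked, list)
    return total(picked)


print(total_with_ignore([True, False, True], [1.0, 2.0, 3.0]))  # 4.0
```

The ignore comment and the cast cost nothing at runtime but hide real mistakes if the assumption ever stops holding. The assertion is checked on every call, unless Python runs with `-O`, which strips assertions.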

jakevdp assigned superbobry and unassigned jakevdp Apr 28, 2025
@jakevdp
Collaborator

jakevdp commented Apr 28, 2025

Re-assigning to @superbobry who may have ideas on how to thread a path through all the implementations here.

@carlosgmartin
Contributor Author

carlosgmartin commented Apr 28, 2025

Do you know what command and options pytype is being invoked with, so I can reproduce it locally and see the errors?

@jakevdp
Collaborator

jakevdp commented Apr 28, 2025

> Do you know what command and options pytype is being invoked with, so I can reproduce it locally and see the errors?

I don't know – @superbobry may have ideas. Note that the errors I'm seeing are not in JAX itself, but in downstream libraries that use JAX.

@jakevdp
Collaborator

jakevdp commented Apr 28, 2025

An example of the types of errors this is leading to:

.../jax/experimental/pallas/ops/tpu/all_gather.py:58:3: error: in get_neighbor: bad return type [bad-return-type]
           Expected: tuple[jax._src.basearray.Array, ...]
  Actually returned: tuple[tuple[jax._src.basearray.Array, ...], ...]

  return tuple(mesh_index)
  ~~~~~~~~~~~~~~~~~~~~~~~~

From this function:

def get_neighbor(
    idx: jax.Array, mesh: jax.sharding.Mesh, axis_name: str, *, direction: str
) -> tuple[jax.Array, ...]:
  """Helper function that computes the mesh indices of a neighbor."""
  axis_names = mesh.axis_names
  which_axis = axis_names.index(axis_name)
  mesh_index = [
      idx if i == which_axis else lax.axis_index(a)
      for i, a in enumerate(axis_names)
  ]
  axis_size = lax.axis_size(axis_name)
  if direction == "right":
    next_idx = lax.rem(idx + 1, axis_size)
  else:
    left = idx - 1
    next_idx = jnp.where(left < 0, left + axis_size, left)
  mesh_index[which_axis] = next_idx
  return tuple(mesh_index)

It looks like pytype is not recognizing the correct overload in this context.

@superbobry
Collaborator

I suspect fixing this for pytype might be a question of reordering the overloads :)

Will experiment and report back.

superbobry added the pull ready label and removed the kokoro:force-run and pull ready labels May 7, 2025
Labels
pull ready Ready for copybara import and testing

3 participants