-
Notifications
You must be signed in to change notification settings - Fork 3k
Fix overloaded type signature for jax.numpy.where. #28314
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good!
Unfortunately we can't merge this because it leads to a bunch of pytype errors. It looks like pytype requires the removed overload in order to function correctly. |
@jakevdp pytype doesn't seem to be run by pre-commit, so I don't see that error:
Should pytype be added to pre-commit? |
We've experimented with adding pytype to the github pre-commit, but it's too slow. |
Static type checking is a mess (not just in JAX, but in Python in general). Multiple implementations, each of which diverge in different directions at the edges. If you're running into issues, the best solution is probably to use your tool-specific ignore statement and move on. |
Re-assigning to @superbobry who may have ideas on how to thread a path through all the implementations here. |
Do you know what command and options pytype is being invoked with, so I can reproduce it locally and see the errors? |
I don't know – @superbobry may have ideas. Note that the errors I'm seeing are not in JAX itself, but in downstream libraries that use JAX. |
An example of the types of errors this is leading to:
From this function: jax/jax/experimental/pallas/ops/tpu/all_gather.py Lines 41 to 58 in 9fa4fb7
It looks like pytype is not recognizing the correct overload in this context. |
I suspect fixing this for pytype might be a question of reordering the overloads :) Will experiment and report back. |
The current type annotation for jax.numpy.where causes type errors that require assertions to silence. Here is a simple example:
test.py
:Specifically, this is caused by the third overload here and here, which is superfluous (see the example here).
This commit removes that overload.
Coincidentally, removing the overload unmasked a couple of minor type issues, which are also fixed in this commit.