Skip to content

Conversation

@infalmo
Copy link

@infalmo infalmo commented Oct 30, 2025

Follow up to #33001. Lmk if any test updates should be made.

@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @infalmo, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request refactors the implementation of jax.numpy.eye by removing a static clipping mechanism for its k parameter. This change enables the k parameter to accept dynamic values, enhancing the flexibility and expressiveness of jnp.eye in JAX computations. It is a follow-up to a previous change, streamlining the handling of diagonal offsets.

Highlights

  • Dynamic k parameter for jnp.eye: The k parameter in jax.numpy.eye (and its internal offset parameter) no longer undergoes static range clipping, allowing it to be a dynamic value.
  • Removal of _clip_int_to_valid_range: The helper function _clip_int_to_valid_range and its usages have been removed, simplifying the codebase.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request aims to make the k parameter of jnp.eye non-static by removing a call to _clip_int_to_valid_range which forces its argument to be concrete. The changes look good in principle, but I found a critical issue where np.int32() is used on the offset, which will fail for non-static (tracer) values. I've suggested a fix for this. I also noticed a minor type hint inaccuracy that I've commented on.

@infalmo
Copy link
Author

infalmo commented Oct 30, 2025

In the corresponding test at tests/lax_numpy_test.py, why is k stress-tested on 2**100? This isn't done over testTri, tho both will overflow under the changes made in #33001, no?

  @jtu.sample_product(
    dtype=default_dtypes,
    n=[0, 4],
    m=[None, 0, 1, 3, 4],
    k=[*range(-4, 4), -2**100, 2**100],
  )
  def testEye(self, n, m, k, dtype):
      ....

Copy link
Collaborator

@jakevdp jakevdp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tests/lax_numpy_test.py::LaxBackedNumpyTests::testEye fails with this change; that will have to be addressed.

Also, since the goal of this change is to allow non-static values of offset, this change will need to add tests cases that use non-static offset values: that would catch the issue flagged by the gemini review bot, where a traced offset is passed to np.int32.

@jakevdp jakevdp self-assigned this Oct 30, 2025
@jakevdp
Copy link
Collaborator

jakevdp commented Oct 31, 2025

We recently did a similar change to lax._tri, and used this validation for large offsets:

jax/jax/_src/lax/lax.py

Lines 3482 to 3490 in afbd5ff

offset = asarray(core.dimension_as_value(offset))
if not dtypes.issubdtype(offset, np.integer):
raise TypeError(f"offset must be an integer, got {offset!r}")
shape_dtype = lax_utils.int_dtype_for_shape(shape, signed=True)
if (
np.iinfo(offset.dtype).min < np.iinfo(shape_dtype).min
or np.iinfo(offset.dtype).max > np.iinfo(shape_dtype).max
):
shape_dtype = np.dtype(np.int64)

You could use a similar approach here, perhaps by factoring this logic into a helper function to use in both places.

@infalmo infalmo requested a review from jakevdp November 2, 2025 09:31
@infalmo
Copy link
Author

infalmo commented Nov 3, 2025

You could use a similar approach here, perhaps by factoring this logic into a helper function to use in both places.

If lax_utils.int_dtype_for_dim worked on tracers, I could just use the dtype of k as the shape_dtype. Will make this change if you think its okay.

@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Nov 3, 2025
n=[0, 4],
m=[None, 0, 1, 3, 4],
k=[*range(-4, 4), -2**100, 2**100],
k=[*range(-4, 4), -2**33, 2**33],
Copy link
Collaborator

@jakevdp jakevdp Nov 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the error if you pass 2**100? This previously worked correctly, and the test was added due to downstream failures in array_api_tests. I'm not sure whether that case is still relevant, but if so this is a breaking change that may be rolled back.

How hard would it be to continue supporting large ints here?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

kokoro:force-run pull ready Ready for copybara import and testing

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants