
Conversation

@hertschuh
Collaborator

By removing the use of `ops.nonzero`, which returns an array of non-predetermined size.

Follow-up to https://github.com/keras-team/keras/pull/21765/files#r2462099871

Fixes keras-team#19376
@gemini-code-assist
Contributor

Summary of Changes

Hello @hertschuh, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a critical compilation issue within Keras's confusion metrics, specifically when utilizing the JAX backend. The problem stemmed from the use of ops.nonzero, which could return arrays of non-predetermined sizes, hindering JIT compilation. To resolve this, the implementation of _find_max_under_constraint in the confusion metrics has been refactored to use a static-shape masking operation instead of ops.nonzero, ensuring compatibility with JAX's compilation requirements. This change also allowed for a minor simplification in the JAX backend's nonzero function.

Highlights

  • Resolved Compilation Issues for Confusion Metrics: The pull request fixes problems that prevented Keras's confusion metrics from compiling, particularly when using the JAX backend.
  • Replaced Dynamic ops.nonzero with Static Masking: The core change involves replacing the use of ops.nonzero in _find_max_under_constraint with a masking operation. This eliminates the issue of dynamically sized arrays, which are incompatible with JAX's JIT compilation.
  • Backend nonzero Simplification: The JAX backend's nonzero function has been simplified by removing special handling for tracers, as the metric no longer relies on its previous behavior.
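The masking pattern behind the fix can be illustrated in plain Python. This is a minimal sketch of the idea only, not the actual Keras implementation; the function name and inputs here are hypothetical stand-ins:

```python
def find_max_under_constraint(constrained, dependent, target):
    """Max of `dependent` over entries where `constrained` meets `target`.

    Instead of gathering feasible indices (a dynamically sized result,
    as ops.nonzero produces), infeasible entries are masked to 0.0, so
    every intermediate value keeps a static, predetermined size.
    """
    feasible = [1.0 if c >= target else 0.0 for c in constrained]
    masked = [d * f for d, f in zip(dependent, feasible)]
    # The leading 0.0 mirrors `initial=0`: if nothing is feasible,
    # the result is 0 rather than an error on an empty set.
    return max([0.0] + masked)

# Example: best precision among thresholds where recall >= 0.5.
recalls = [0.2, 0.6, 0.9]
precisions = [0.8, 0.5, 0.3]
print(find_max_under_constraint(recalls, precisions, 0.5))  # 0.5
```

Because the mask has the same shape as the input, a JIT compiler can trace the whole computation with fixed array sizes, which is what makes the metric compilable under the JAX backend.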
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

  • Code Review (/gemini review): Performs a code review for the current pull request in its current state.
  • Pull Request Summary (/gemini summary): Provides a summary of the current pull request in its current state.
  • Comment (@gemini-code-assist): Responds in comments when explicitly tagged, both in pull request comments and review comments.
  • Help (/gemini help): Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@hertschuh
Collaborator Author

@danielenricocahall

Thanks!

@gemini-code-assist bot left a comment

Code Review

This pull request refactors the confusion matrix-based metrics to make them JAX-compilable by removing the usage of ops.nonzero. The new implementation in _find_max_under_constraint uses a masking approach which is a good, JIT-compatible pattern. I've suggested a minor refinement to use ops.where for potentially better readability. The associated cleanup in the JAX backend, which removes now-unnecessary special handling for nonzero, is also a good improvement. Additionally, a typo in a docstring was fixed. Overall, these are solid changes that improve the codebase.

Comment on lines +669 to +672

    return ops.max(
        ops.multiply(dependent, ops.cast(feasible, dependent.dtype)),
        initial=0,
    )
Severity: medium

This implementation is correct and makes the function JIT-compatible. For slightly improved readability, you could consider using ops.where to explicitly mask the dependent tensor. This avoids the implicit boolean-to-float conversion from ops.cast and can make the intent clearer.

Suggested change

    -    return ops.max(
    -        ops.multiply(dependent, ops.cast(feasible, dependent.dtype)),
    -        initial=0,
    -    )
    +    return ops.max(
    +        ops.where(feasible, dependent, 0.0),
    +        initial=0,
    +    )
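For non-negative metric values the two formulations produce the same result. A plain-Python sketch of both (hypothetical function names, not the Keras code):

```python
def masked_max_multiply(feasible, dependent):
    # cast-and-multiply pattern: False -> 0.0, True -> 1.0, then multiply.
    masked = [d * (1.0 if f else 0.0) for f, d in zip(feasible, dependent)]
    return max([0.0] + masked)  # [0.0] plays the role of initial=0

def masked_max_where(feasible, dependent):
    # where-style pattern: keep the value when feasible, else substitute 0.0.
    masked = [d if f else 0.0 for f, d in zip(feasible, dependent)]
    return max([0.0] + masked)

feasible = [False, True, True]
dependent = [0.8, 0.5, 0.3]
print(masked_max_multiply(feasible, dependent))  # 0.5
print(masked_max_where(feasible, dependent))     # 0.5
```

Beyond readability, the where-style form also drops masked values outright instead of multiplying by zero, so it would not propagate a NaN from an infeasible slot (0.0 * nan is nan, while the where branch never touches the value).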

@codecov-commenter

codecov-commenter commented Oct 24, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 82.63%. Comparing base (18f79d6) to head (baca199).

Additional details and impacted files
@@            Coverage Diff             @@
##           master   #21775      +/-   ##
==========================================
- Coverage   82.64%   82.63%   -0.01%     
==========================================
  Files         577      577              
  Lines       59254    59249       -5     
  Branches     9292     9291       -1     
==========================================
- Hits        48968    48963       -5     
  Misses       7903     7903              
  Partials     2383     2383              
Flag Coverage Δ
keras 82.46% <100.00%> (-0.01%) ⬇️
keras-jax 63.34% <100.00%> (-0.01%) ⬇️
keras-numpy 57.57% <100.00%> (-0.01%) ⬇️
keras-openvino 34.31% <0.00%> (+<0.01%) ⬆️
keras-tensorflow 64.11% <100.00%> (+<0.01%) ⬆️
keras-torch 63.65% <100.00%> (+<0.01%) ⬆️

Flags with carried forward coverage won't be shown.


@fchollet left a comment

Thanks!

@google-ml-butler google-ml-butler bot added the kokoro:force-run and ready to pull labels Oct 24, 2025
@fchollet fchollet merged commit 10b51ce into keras-team:master Oct 24, 2025
9 checks passed
@google-ml-butler google-ml-butler bot removed the awaiting review, ready to pull, and kokoro:force-run labels Oct 24, 2025
@danielenricocahall
Contributor

> @danielenricocahall
> Thanks!

Ah I'm sorry! I was too gung ho with my fix, thank you for doing this!

@hertschuh hertschuh deleted the confusion_max branch October 27, 2025 16:56
Successfully merging this pull request may close these issues:

  • Inconsistent manner of the metric SpecificityAtSensitivity among different backends