Closed
Labels
good first issue (Good for newcomers), outreachy (Issues targeted at Outreachy applicants)
Description
Background
Nx recently added support for extended data types including boolean arrays. However, comparison and conditional operations still use uint8 tensors instead of proper boolean tensors.
Objective
Update comparison and conditional operations to use boolean tensors:
- Change comparison ops (`op_cmplt`, `op_cmpne`, etc.) to return `(bool, Dtype.bool_elt) t`
- Update `op_where` to accept boolean condition tensors
- Migrate all mask/condition usage in frontend operations
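To make the type change concrete, here is a minimal sketch of the before/after semantics. It assumes plain OCaml arrays as a stand-in for Nx tensors (the real `('a, 'b) t` also carries a shape and a `Dtype` witness), and the names `cmplt_uint8`, `cmplt_bool`, `where_uint8` and `where_bool` are hypothetical, not part of Nx.

```ocaml
(* Hypothetical stand-in for an Nx tensor: a flat array of elements.
   Shapes and dtype witnesses are ignored so only the mask type differs. *)
type 'a tensor = 'a array

(* Before: comparisons encode truth as 0/1 in an int (uint8-like) tensor. *)
let cmplt_uint8 (a : 'a tensor) (b : 'a tensor) : int tensor =
  Array.map2 (fun x y -> if x < y then 1 else 0) a b

(* After: comparisons return a genuine boolean tensor. *)
let cmplt_bool (a : 'a tensor) (b : 'a tensor) : bool tensor =
  Array.map2 (fun x y -> x < y) a b

(* Before: where treats any non-zero mask element as "true". *)
let where_uint8 (cond : int tensor) (x : 'a tensor) (y : 'a tensor) : 'a tensor =
  Array.init (Array.length cond) (fun i -> if cond.(i) <> 0 then x.(i) else y.(i))

(* After: where consumes the boolean mask directly. *)
let where_bool (cond : bool tensor) (x : 'a tensor) (y : 'a tensor) : 'a tensor =
  Array.init (Array.length cond) (fun i -> if cond.(i) then x.(i) else y.(i))
```

For example, `where_bool (cmplt_bool a b) a b` selects the elementwise minimum of `a` and `b`, exactly as the uint8 version did; only the mask type at the call site changes, which is what the existing tests should confirm.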
Design
Current signatures:

```ocaml
val op_cmplt : ('a, 'b) t -> ('a, 'b) t -> (int, Dtype.uint8_elt) t
val op_where : (int, Dtype.uint8_elt) t -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
```

New signatures:

```ocaml
val op_cmplt : ('a, 'b) t -> ('a, 'b) t -> (bool, Dtype.bool_elt) t
val op_where : (bool, Dtype.bool_elt) t -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
```

Implementation Areas
- Backend Interface (`nx/lib/core/backend_intf.ml`)
- Native Backend (`nx/lib/native/`)
- C Backend (`nx/lib/c/`)
- Metal Backend (`nx/lib/metal/`)
- Frontend (`nx/lib/core/frontend.ml`)
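All of these areas hang off the backend interface. The following is a hedged sketch of how the new signatures could thread through a backend and the frontend, assuming a heavily simplified interface with the dtype parameter dropped. The module type `BACKEND`, the toy `Native` module and the `Frontend` functor with its `minimum`/`maximum` functions are illustrative inventions, not the real contents of `backend_intf.ml` or `frontend.ml`.

```ocaml
(* Hypothetical, simplified backend interface: only the two ops whose
   signatures change are shown, and tensors carry no dtype witness. *)
module type BACKEND = sig
  type 'a t
  val of_array : 'a array -> 'a t
  val to_array : 'a t -> 'a array
  val op_cmplt : 'a t -> 'a t -> bool t
  val op_where : bool t -> 'a t -> 'a t -> 'a t
end

(* Toy "native" backend: one OCaml loop per op over flat arrays. *)
module Native : BACKEND = struct
  type 'a t = 'a array
  let of_array a = Array.copy a
  let to_array a = Array.copy a
  let op_cmplt a b = Array.map2 (fun x y -> x < y) a b
  let op_where cond x y =
    Array.init (Array.length cond) (fun i -> if cond.(i) then x.(i) else y.(i))
end

(* Frontend ops built on any backend: the mask produced by op_cmplt is
   passed straight to op_where as a bool tensor, with no uint8 round-trip. *)
module Frontend (B : BACKEND) = struct
  let minimum a b = B.op_where (B.op_cmplt a b) a b
  let maximum a b = B.op_where (B.op_cmplt a b) b a
end

module F = Frontend (Native)
```

Because `op_cmplt` returns `bool t` and `op_where` consumes it, the type checker rejects any frontend call site that still builds or expects an integer mask, which offers one way to locate the usages that need migrating.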
Testing
- Ensure all existing tests pass with bool tensors
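Alongside the existing tests, boolean masks make a few invariants easy to state directly. A small property-style sketch, again on plain arrays with hypothetical helpers (`cmplt`, `cmpeq`, `cmpne`, `where`, `check_invariants`) standing in for the real Nx API:

```ocaml
(* Toy bool-returning comparisons on flat arrays (stand-ins, not Nx). *)
let cmplt a b = Array.map2 ( < ) a b
let cmpeq a b = Array.map2 ( = ) a b
let cmpne a b = Array.map2 ( <> ) a b
let where cond x y =
  Array.init (Array.length cond) (fun i -> if cond.(i) then x.(i) else y.(i))

(* Invariants that should hold for any same-length inputs once masks are bool. *)
let check_invariants a b =
  let lt = cmplt a b and gt = cmplt b a in
  (* cmpne is exactly the logical negation of cmpeq *)
  cmpne a b = Array.map not (cmpeq a b)
  (* a < b and b < a are never both true *)
  && Array.for_all2 (fun x y -> not (x && y)) lt gt
  (* swapping the branches of where is the same as negating the mask *)
  && where (Array.map not lt) b a = where lt a b

(* Run the invariants on random small inputs; the narrow value range
   ensures ties occur, so the equality cases are exercised too. *)
let () =
  Random.init 42;
  for _i = 1 to 100 do
    let n = 1 + Random.int 16 in
    let a = Array.init n (fun _ -> Random.int 5)
    and b = Array.init n (fun _ -> Random.int 5) in
    assert (check_invariants a b)
  done
```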