
Replace uint8 Masks with Bool Tensors #54

@tmattio

Description


Background

Nx recently added support for extended data types, including boolean arrays. However, comparison and conditional operations still represent masks as uint8 tensors instead of proper boolean tensors.

Objective

Update comparison and conditional operations to use boolean tensors:

  • Change comparison ops (op_cmplt, op_cmpne, etc.) to return (bool, Dtype.bool_elt) t
  • Update op_where to accept boolean condition tensors
  • Migrate all mask/condition usage in frontend operations

Design

Current signatures:

val op_cmplt : ('a, 'b) t -> ('a, 'b) t -> (int, Dtype.uint8_elt) t
val op_where : (int, Dtype.uint8_elt) t -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
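To make the current convention concrete, here is a toy model in plain OCaml. Arrays stand in for Nx tensors and `int` stands in for `(int, Dtype.uint8_elt)`. The names `cmplt_uint8`, `where_uint8` and `minimum` are hypothetical and not part of Nx, and treating any non-zero value as true is an assumption of this sketch. It shows the weakness the issue targets: nothing in the mask type rules out a value such as 7.

```ocaml
(* Toy model of the CURRENT convention. Plain OCaml arrays stand in for
   Nx tensors and [int] stands in for (int, Dtype.uint8_elt); none of this
   is the Nx API. *)

(* Element-wise "less than" encoding the result as 0 or 1, mirroring
   op_cmplt returning a uint8 tensor. *)
let cmplt_uint8 (a : 'a array) (b : 'a array) : int array =
  Array.map2 (fun x y -> if x < y then 1 else 0) a b

(* Element-wise select driven by an integer mask, mirroring op_where with
   a uint8 condition. This sketch treats any non-zero value as true, so a
   mask holding 7 is accepted silently: the type cannot rule it out. *)
let where_uint8 (cond : int array) (x : 'a array) (y : 'a array) : 'a array =
  let n = Array.length cond in
  if Array.length x <> n || Array.length y <> n then
    invalid_arg "where_uint8: length mismatch";
  Array.init n (fun i -> if cond.(i) <> 0 then x.(i) else y.(i))

(* Element-wise minimum built from the two ops. *)
let minimum a b = where_uint8 (cmplt_uint8 a b) a b
```

For example, `minimum [|1; 5; 3|] [|4; 2; 3|]` evaluates to `[|1; 2; 3|]`.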

New signatures:

val op_cmplt : ('a, 'b) t -> ('a, 'b) t -> (bool, Dtype.bool_elt) t
val op_where : (bool, Dtype.bool_elt) t -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
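The same kind of toy model under the new signatures: masks become `bool` arrays, standing in for `(bool, Dtype.bool_elt) t`. The functions `cmplt`, `cmpne` and `where` below are hypothetical stand-ins for `op_cmplt`, `op_cmpne` and `op_where`, not the Nx implementations.

```ocaml
(* Toy model of the NEW convention: masks are bool arrays, standing in for
   (bool, Dtype.bool_elt) t. Hypothetical stand-ins, not the Nx API. *)

(* Comparisons return booleans directly; Array.map2 raises
   Invalid_argument if the inputs differ in length. *)
let cmplt (a : 'a array) (b : 'a array) : bool array = Array.map2 ( < ) a b
let cmpne (a : 'a array) (b : 'a array) : bool array = Array.map2 ( <> ) a b

(* The condition is a bool array, so each element is already true or
   false and no "non-zero means true" convention is needed. *)
let where (cond : bool array) (x : 'a array) (y : 'a array) : 'a array =
  let n = Array.length cond in
  if Array.length x <> n || Array.length y <> n then
    invalid_arg "where: length mismatch";
  Array.init n (fun i -> if cond.(i) then x.(i) else y.(i))
```

Because the condition's type now only admits `true` or `false`, a malformed mask such as one holding 7 can no longer be constructed.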

Implementation Areas

  1. Backend Interface (nx/lib/core/backend_intf.ml)
  2. Native Backend (nx/lib/native/)
  3. C Backend (nx/lib/c/)
  4. Metal Backend (nx/lib/metal/)
  5. Frontend (nx/lib/core/frontend.ml)
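One way to picture how these areas fit together is a cut-down functor sketch. It assumes, as the file layout suggests, that frontend operations are written once against the backend interface and instantiated per backend. The `BACKEND` signature and the `Array_backend` and `Frontend` modules are hypothetical and keep only the two ops this issue changes.

```ocaml
(* Hypothetical, cut-down backend signature with only the ops this issue
   changes; the real backend_intf.ml has many more ops and uses Nx's own
   tensor and dtype types. *)
module type BACKEND = sig
  type 'a t
  val of_array : 'a array -> 'a t
  val to_array : 'a t -> 'a array
  val op_cmplt : 'a t -> 'a t -> bool t
  val op_where : bool t -> 'a t -> 'a t -> 'a t
end

(* Array-based stand-in for a real backend (native, C, Metal). *)
module Array_backend : BACKEND = struct
  type 'a t = 'a array
  let of_array a = Array.copy a
  let to_array a = Array.copy a
  let op_cmplt a b = Array.map2 ( < ) a b
  let op_where cond x y =
    let n = Array.length cond in
    if Array.length x <> n || Array.length y <> n then
      invalid_arg "op_where: length mismatch";
    Array.init n (fun i -> if cond.(i) then x.(i) else y.(i))
end

(* Frontend code depends only on the signature, so it works unchanged
   with any backend that implements bool-typed masks. *)
module Frontend (B : BACKEND) = struct
  let minimum a b = B.op_where (B.op_cmplt a b) a b
  let maximum a b = B.op_where (B.op_cmplt a b) b a
end

module F = Frontend (Array_backend)
```

With this structure, changing the mask type in the signature makes the compiler reject every backend implementation and frontend call site that still assumes uint8, which gives the migration a checklist for free.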

Testing

  • Ensure all existing tests pass with bool tensors
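For the testing step, one possible regression check is to compare each new bool result with the old 0/1 result read as "non-zero means true". The sketch below does this on toy array stand-ins; `cmplt_uint8`, `cmplt_bool`, `agrees` and `random_sweep` are hypothetical helpers, not existing Nx test utilities.

```ocaml
(* Hypothetical reference for the old behaviour: 0/1 integers. *)
let cmplt_uint8 a b = Array.map2 (fun x y -> if x < y then 1 else 0) a b

(* Hypothetical version of the new behaviour: booleans. *)
let cmplt_bool a b = Array.map2 ( < ) a b

(* The new mask must match the old one element-wise when the old one is
   read as "non-zero means true". *)
let agrees a b =
  let old_mask = cmplt_uint8 a b in
  let new_mask = cmplt_bool a b in
  Array.map (fun o -> o <> 0) old_mask = new_mask

(* Run [agrees] on [trials] random pairs of equal-length int arrays,
   using the global Random state. *)
let random_sweep ~trials =
  let ok = ref true in
  for _i = 1 to trials do
    let n = 1 + Random.int 16 in
    let a = Array.init n (fun _ -> Random.int 10) in
    let b = Array.init n (fun _ -> Random.int 10) in
    if not (agrees a b) then ok := false
  done;
  !ok
```

Calling `Random.init 42` before `random_sweep ~trials:100` makes the sweep reproducible across runs.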

Metadata


Assignees

No one assigned

Labels

good first issue (Good for newcomers), outreachy (Issues targeted at Outreachy applicants)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
