Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ All notable changes to this project will be documented in this file.
- Rune: Add support for categorical sampling with `Rune.Rng.categorical` (#89, @nirnayroy).
- Nx: Add float16 and bfloat16 support to safetensors I/O, including precise conversions that preserve denormals/NaNs (#84, @six-shot, @tmattio).
- Talon: Allow forcing column types in Talon JSON loader (#104, @nirnayroy).
- Nx: Update comparison and conditional operations to use boolean tensors (#54, @nirnayroy).

## [1.0.0~alpha1] - 2025-10-02

Expand Down
2 changes: 1 addition & 1 deletion kaun/lib/kaun/ops.ml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ let scaled_dot_product_attention ?attention_mask ?(dropout = 0.0) ?is_causal
let ones_matrix = ones dtype [| seq_len_q; seq_len_k |] in
let causal_mask = tril ones_matrix in
let causal_mask = reshape mask_shape causal_mask in
let causal_mask = cast uint8 causal_mask in
let causal_mask = cast bool causal_mask in
let neg_inf = scalar dtype (-1e9) in
where causal_mask scores neg_inf)
else scores
Expand Down
4 changes: 2 additions & 2 deletions kaun/lib/kaun/ops.mli
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
val scaled_dot_product_attention :
?attention_mask:(int, Rune.uint8_elt) Rune.t ->
?attention_mask:(bool, Rune.bool_elt) Rune.t ->
?dropout:float ->
?is_causal:bool ->
?scale:float ->
Expand All @@ -23,7 +23,7 @@ val multi_head_attention :
query:(float, 'a) Rune.t ->
?key:(float, 'a) Rune.t ->
?value:(float, 'a) Rune.t ->
?attention_mask:(int, Rune.uint8_elt) Rune.t ->
?attention_mask:(bool, Rune.bool_elt) Rune.t ->
?is_causal:bool ->
?rngs:Rune.Rng.key ->
embed_dim:int ->
Expand Down
4 changes: 2 additions & 2 deletions kaun/test/test_checkpoint.ml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ let test_save_and_load () =
| Some loaded_w ->
let is_equal = Rune.all (Rune.equal w loaded_w) in
let is_equal_val = Rune.to_array is_equal in
A.check A.bool "weights match" true (is_equal_val.(0) > 0)
A.check A.bool "weights match" true is_equal_val.(0)
| None -> A.fail "weight is not a tensor")
| None -> A.fail "weight not found");

Expand All @@ -37,7 +37,7 @@ let test_save_and_load () =
| Some loaded_b ->
let is_equal = Rune.all (Rune.equal b loaded_b) in
let is_equal_val = Rune.to_array is_equal in
A.check A.bool "bias matches" true (is_equal_val.(0) > 0)
A.check A.bool "bias matches" true is_equal_val.(0)
| None -> A.fail "bias is not a tensor")
| None -> A.fail "bias not found"

Expand Down
12 changes: 6 additions & 6 deletions nx/lib/backend_c/nx_c.ml
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,13 @@ external caml_pow :
external caml_cmplt :
('a, 'b) ffi_tensor ->
('a, 'b) ffi_tensor ->
(int, Dtype.uint8_elt) ffi_tensor ->
(bool, Dtype.bool_elt) ffi_tensor ->
unit = "caml_nx_cmplt"

external caml_cmpne :
('a, 'b) ffi_tensor ->
('a, 'b) ffi_tensor ->
(int, Dtype.uint8_elt) ffi_tensor ->
(bool, Dtype.bool_elt) ffi_tensor ->
unit = "caml_nx_cmpne"

external caml_xor :
Expand Down Expand Up @@ -95,7 +95,7 @@ external caml_recip : ('a, 'b) ffi_tensor -> ('a, 'b) ffi_tensor -> unit

(* Ternary operation FFI declarations *)
external caml_where :
(int, Dtype.uint8_elt) ffi_tensor ->
(bool, Dtype.bool_elt) ffi_tensor ->
('a, 'b) ffi_tensor ->
('a, 'b) ffi_tensor ->
('a, 'b) ffi_tensor ->
Expand Down Expand Up @@ -325,7 +325,7 @@ let binary_op op_name ffi_op x y =

out

(* Comparison operation that returns uint8 *)
(* Comparison operation that returns bool *)
let comparison_op op_name ffi_op x y =
(* Ensure both inputs have the same shape *)
let x_shape = shape x in
Expand All @@ -342,8 +342,8 @@ let comparison_op op_name ffi_op x y =
let x' = ensure_materializable x in
let y' = ensure_materializable y in

(* Create output tensor with uint8 dtype *)
let out = create_tensor x.context Dtype.uint8 x_shape in
(* Create output tensor with bool dtype *)
let out = create_tensor x.context Dtype.bool x_shape in

(* Convert to FFI tensors *)
let x_ffi = to_ffi_tensor x' in
Expand Down
22 changes: 11 additions & 11 deletions nx/lib/backend_c/nx_c_binary.c
Original file line number Diff line number Diff line change
Expand Up @@ -714,7 +714,7 @@ static void nx_c_cmplt_f16_kernel(void *x_data, void *y_data, void *z_data,
long x_off, long y_off, long z_off) {
uint16_t *x = (uint16_t *)x_data;
uint16_t *y = (uint16_t *)y_data;
uint8_t *z = (uint8_t *)z_data;
bool *z = (bool *)z_data;
float a = half_to_float(x[x_off]);
float b = half_to_float(y[y_off]);
z[z_off] = a < b ? 1 : 0;
Expand All @@ -728,7 +728,7 @@ BINARY_OP_IMPL(cmplt, uint16_t, f16)
long y_off, long z_off) { \
T *x = (T *)x_data; \
T *y = (T *)y_data; \
uint8_t *z = (uint8_t *)z_data; \
bool *z = (bool *)z_data; \
float a = TO_FLOAT(x[x_off]); \
float b = TO_FLOAT(y[y_off]); \
z[z_off] = OP(a, b); \
Expand All @@ -750,7 +750,7 @@ LOW_PREC_CMP_KERNEL(cmplt, caml_ba_fp8_e5m2, f8e5m2, CMPLT_OP,
long y_off, long z_off) { \
uint8_t *x = (uint8_t *)x_data; \
uint8_t *y = (uint8_t *)y_data; \
uint8_t *z = (uint8_t *)z_data; \
bool *z = (bool *)z_data; \
/* Unpack x value */ \
long x_byte_off = x_off / 2; \
int x_nib_off = x_off % 2; \
Expand Down Expand Up @@ -797,11 +797,11 @@ LOW_PREC_CMP_KERNEL(cmplt, caml_ba_fp8_e5m2, f8e5m2, CMPLT_OP,
}

// Define comparison operators
#define CMPGT_OP(x, y) ((x) > (y) ? 1 : 0)
#define CMPLE_OP(x, y) ((x) <= (y) ? 1 : 0)
#define CMPGE_OP(x, y) ((x) >= (y) ? 1 : 0)
#define CMPEQ_OP(x, y) ((x) == (y) ? 1 : 0)
#define CMPNE_OP(x, y) ((x) != (y) ? 1 : 0)
#define CMPGT_OP(x, y) ((x) > (y) ? true : false)
#define CMPLE_OP(x, y) ((x) <= (y) ? true : false)
#define CMPGE_OP(x, y) ((x) >= (y) ? true : false)
#define CMPEQ_OP(x, y) ((x) == (y) ? true : false)
#define CMPNE_OP(x, y) ((x) != (y) ? true : false)

// Generate int4/uint4 comparison operations
INT4_COMPARISON_OP_IMPL(cmplt, 1, i4, CMPLT_OP)
Expand Down Expand Up @@ -1197,7 +1197,7 @@ static void dispatch_binary_op(value v_x, value v_y, value v_z,
cleanup_ndarray(&z);
}

// Generic dispatch function for comparison operations (output is always uint8)
// Generic dispatch function for comparison operations (output is always bool)
static void dispatch_comparison_op(value v_x, value v_y, value v_z,
const binary_op_table *table,
const char *op_name) {
Expand Down Expand Up @@ -1241,11 +1241,11 @@ static void dispatch_comparison_op(value v_x, value v_y, value v_z,

// Check output is bool
int kind_z = nx_ba_get_kind(Caml_ba_array_val(v_z_data));
if (kind_z != CAML_BA_UINT8) {
if (kind_z != NX_BA_BOOL) {
cleanup_ndarray(&x);
cleanup_ndarray(&y);
cleanup_ndarray(&z);
caml_failwith("dtype mismatch: comparison output must be uint8");
caml_failwith("dtype mismatch: comparison output must be bool");
}

// Select operation based on input dtype
Expand Down
10 changes: 5 additions & 5 deletions nx/lib/core/backend_intf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,11 @@ module type S = sig
val op_pow : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
(** Raise [base] to [exponent]. *)

val op_cmplt : ('a, 'b) t -> ('a, 'b) t -> (int, Dtype.uint8_elt) t
(** Compare [<]. Returns 0 or 1 as uint8. *)
val op_cmplt : ('a, 'b) t -> ('a, 'b) t -> (bool, Dtype.bool_elt) t
(** Compare [<]. Returns False or True as bool. *)

val op_cmpne : ('a, 'b) t -> ('a, 'b) t -> (int, Dtype.uint8_elt) t
(** Compare [<>]. Returns 0 or 1 as uint8. *)
val op_cmpne : ('a, 'b) t -> ('a, 'b) t -> (bool, Dtype.bool_elt) t
(** Compare [<>]. Returns False or True as bool. *)

val op_xor : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
(** Bitwise XOR. *)
Expand Down Expand Up @@ -110,7 +110,7 @@ module type S = sig
(* Ternary Op *)

val op_where :
(int, Dtype.uint8_elt) t -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
(bool, Dtype.bool_elt) t -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
(** Select from [if_true] or [if_false] based on a boolean tensor. *)

(* Reduction Ops *)
Expand Down
40 changes: 23 additions & 17 deletions nx/lib/core/frontend.ml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ module Make (B : Backend_intf.S) = struct
type std_nativeint_t = (nativeint, nativeint_elt) t
type complex32_t = (Complex.t, complex32_elt) t
type complex64_t = (Complex.t, complex64_elt) t
type bool_t = (bool, bool_elt) t

(* Constructor shortcuts *)
let float16 = Float16
Expand All @@ -96,7 +97,7 @@ module Make (B : Backend_intf.S) = struct
| R of int * int (* Range [start, stop) *)
| Rs of int * int * int (* Range with step *)
| A (* All indices *)
| M of (int, uint8_elt) t (* Boolean mask *)
| M of (bool, bool_elt) t (* Boolean mask *)
| N (* New axis *)

(* ───── Basic Tensor Properties ───── *)
Expand Down Expand Up @@ -922,8 +923,8 @@ module Make (B : Backend_intf.S) = struct
B.op_xor a_b b_b

let logical_not (type a b) (a : (a, b) t) =
(* For boolean tensors (uint8), logical not is 1 - x *)
(* But sub doesn't support uint8, so we use XOR with 1 *)
(* For boolean tensors, logical not should flip the bit *)
(* But subtraction isn't supported for bool, so we use XOR with 1 *)
let dt = dtype a in
match dt with
| Dtype.UInt8 | Dtype.Bool | Dtype.UInt4 | Dtype.QUInt8 ->
Expand Down Expand Up @@ -1230,7 +1231,7 @@ module Make (B : Backend_intf.S) = struct

let isinf x =
let dt = dtype x in
if not (Dtype.is_float dt) then zeros (B.context x) Dtype.uint8 (shape x)
if not (Dtype.is_float dt) then zeros (B.context x) Dtype.bool (shape x)
else
let pos_inf_const = B.op_const_scalar (B.context x) Float.infinity dt in
let neg_inf_const =
Expand All @@ -1242,12 +1243,12 @@ module Make (B : Backend_intf.S) = struct

let isnan x =
let dt = dtype x in
if not (Dtype.is_float dt) then zeros (B.context x) Dtype.uint8 (shape x)
if not (Dtype.is_float dt) then zeros (B.context x) Dtype.bool (shape x)
else cmpne x x

let isfinite x =
let dt = dtype x in
if not (Dtype.is_float dt) then ones (B.context x) Dtype.uint8 (shape x)
if not (Dtype.is_float dt) then ones (B.context x) Dtype.bool (shape x)
else logical_not (logical_or (isinf x) (isnan x))

let lerp start_tensor end_tensor weight =
Expand Down Expand Up @@ -1593,7 +1594,7 @@ module Make (B : Backend_intf.S) = struct
if not can_broadcast then
(* If shapes can't be broadcast, arrays are not equal. Return a scalar
False (0). *)
zeros (B.context x) Dtype.uint8 [||]
zeros (B.context x) Dtype.bool [||]
else
(* Check element-wise equality and then check if all are true *)
let eq_result = equal x y in
Expand Down Expand Up @@ -3534,8 +3535,8 @@ module Make (B : Backend_intf.S) = struct
are NOT differentiable. They use unsafe_get to materialize values. *)

(* Forward declaration for mutual recursion *)
let nonzero_indices_only (condition : (int, uint8_elt) t) =
(* Special version for compress that only returns indices for uint8 masks *)
let nonzero_indices_only (condition : (bool, bool_elt) t) =
(* Special version for compress that only returns indices for boolean masks *)
let total = numel condition in
let cond_flat = reshape [| total |] condition in

Expand All @@ -3556,13 +3557,13 @@ module Make (B : Backend_intf.S) = struct
let idx = ref 0 in
for i = 0 to total - 1 do
let elem_val = unsafe_get [ i ] cond_flat in
if elem_val <> 0 then (
if elem_val then (
set_item [ !idx ] (Int32.of_int i) indices;
incr idx)
done;
[| indices |]

let compress ?axis ~(condition : (int, uint8_elt) t) t =
let compress ?axis ~(condition : (bool, bool_elt) t) t =
match axis with
| None ->
(* Flatten and compress *)
Expand Down Expand Up @@ -3591,9 +3592,9 @@ module Make (B : Backend_intf.S) = struct
(numel condition) axis axis_size);

(* Get indices where condition is true *)
let true_indices =
nonzero_indices_only (reshape [| axis_size |] condition)
in
let cond_1d = reshape [| axis_size |] condition in
let true_indices = nonzero_indices_only cond_1d in

if Array.length true_indices = 0 || numel true_indices.(0) = 0 then (
(* No true values - return empty tensor *)
let new_shape = Array.copy (shape t) in
Expand All @@ -3620,7 +3621,7 @@ module Make (B : Backend_intf.S) = struct
let total = numel mask in
let mask_flat = reshape [| total |] mask in

(* Count non-zeros - mask is uint8 (0 or 1) *)
(* Count non-zeros - mask is boolean (true or false) *)
let n_nonzero =
let sum_result = sum (astype Int32 mask_flat) in
let scalar_val = squeeze sum_result |> unsafe_get [] in
Expand Down Expand Up @@ -3651,7 +3652,7 @@ module Make (B : Backend_intf.S) = struct
let is_nonzero_tensor = not_equal elem zero_scalar in
(* We need to materialize this to iterate - this breaks
differentiability *)
let is_nonzero = unsafe_get [] is_nonzero_tensor <> 0 in
let is_nonzero = unsafe_get [] is_nonzero_tensor <> false in

if is_nonzero then (
(* Store coordinates *)
Expand Down Expand Up @@ -7661,7 +7662,12 @@ module Make (B : Backend_intf.S) = struct
shape_for_arange.(ndim_expanded - 1) <- num_classes;
let arange_b = reshape shape_for_arange arange_x in

cmpeq index_expanded arange_b (* Broadcasts to one-hot mask *)
(* Broadcasts to one-hot mask *)
let bool_to_uint (x : (bool, bool_elt) t) : (int, uint8_elt) t =
cast Dtype.uint8 x
in
let bool_tensor = cmpeq index_expanded arange_b in
bool_to_uint bool_tensor

(** Internal N-Dimensional max unpooling. *)
let max_unpool_nd ~kernel_size ?stride ?dilation ~padding_spec
Expand Down
Loading
Loading