Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ All notable changes to this project will be documented in this file.
- Saga: Fix Unigram `token_to_id`/`id_to_token` vocabulary lookups (#117, @RidwanAdebosin)
- Nx: Fix `matrix_rank`/`pinv` Hermitian fast paths to use eigen-decomposition and match NumPy for complex inputs (#96, @six-shot, @tmattio).
- Fehu: Finish clipped value loss support in Fehu.Training (#107, @nirnayroy)
- Nx: Fix complex vdot to conjugate first tensor before multiplication, ensuring correct mathematical behavior (#123, @Arsalaan-Alam)

## [1.0.0~alpha1] - 2025-10-02

Expand Down
3 changes: 1 addition & 2 deletions nx/lib/core/frontend.ml
Original file line number Diff line number Diff line change
Expand Up @@ -4546,8 +4546,7 @@ module Make (B : Backend_intf.S) = struct
(* For complex types, conjugate first vector *)
match dtype a with
| (Complex32 | Complex64) when dtype a = dtype b ->
(* TODO: implement conj when available *)
sum (mul flat_a flat_b)
sum (mul (conjugate flat_a) flat_b)
| _ -> sum (mul flat_a flat_b)

let vecdot ?axis x1 x2 =
Expand Down
14 changes: 14 additions & 0 deletions nx/test/test_nx_linalg.ml
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,19 @@ let test_vdot () =
let res2 = Nx.vdot a2 b in
check_t "vdot flatten" [||] [| 4. +. 10. +. 18. +. 16. +. 25. +. 36. |] res2

let test_vdot_complex () =
(* Test complex vdot with conjugation *)
let a = Nx.create Nx.complex32 [| 2 |]
[| Complex.{ re = 1.; im = 2. }; Complex.{ re = 3.; im = 4. } |] in
let b = Nx.create Nx.complex32 [| 2 |]
[| Complex.{ re = 5.; im = 6. }; Complex.{ re = 7.; im = 8. } |] in
let result = Nx.vdot a b in
(* Expected: conj(a) * b = [(1-2i)(5+6i), (3-4i)(7+8i)] = [17-4i, 53-4i] = 70-8i *)
let expected = Complex.{ re = 70.; im = -8. } in
let actual = Nx.item [] result in
check (float 1e-6) "vdot complex real part" expected.re actual.re;
check (float 1e-6) "vdot complex imag part" expected.im actual.im

let test_vdot_mismatch () =
let a = Nx.create Nx.float32 [| 3 |] [| 1.; 2.; 3. |] in
let b = Nx.create Nx.float32 [| 4 |] [| 4.; 5.; 6.; 7. |] in
Expand Down Expand Up @@ -1118,6 +1131,7 @@ let advanced_utility_tests =
let product_tests =
[
("vdot", `Quick, test_vdot);
("vdot complex", `Quick, test_vdot_complex);
("vdot mismatch", `Quick, test_vdot_mismatch);
("vecdot", `Quick, test_vecdot);
("inner", `Quick, test_inner);
Expand Down
Loading