Closed
Labels: good first issue (Good for newcomers), outreachy (Issues targeted at Outreachy applicants)
Description
Why it matters
The dot product of a vector with itself should be a non-negative real number: its squared norm. For complex tensors this only holds if the first vector is conjugated before multiplying. NumPy's `vdot`, like most linear algebra textbooks, conjugates the first argument by default because it keeps norms and inner products well behaved. Our `Nx.vdot` skips that conjugation step in `nx/lib/core/frontend.ml`, so complex dot products come out wrong.
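Written out, the conjugating definition that NumPy's `vdot` follows for 1-D inputs is below. The second identity shows why the self-product is real and non-negative; the last line is a one-element counterexample for the unconjugated product.

```latex
\operatorname{vdot}(a, b) = \sum_i \overline{a_i}\, b_i
\qquad
\operatorname{vdot}(a, a) = \sum_i \overline{a_i}\, a_i = \sum_i |a_i|^2 \ge 0

\text{Without conjugation, } a = (i) \text{ gives } \sum_i a_i a_i = i^2 = -1 < 0.
```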
How to see the bug
```ocaml
let open Nx in
(* a = [1+2i; 3+4i] *)
let a =
  create complex32 [| 2 |]
    [| Complex.{ re = 1.; im = 2. }; Complex.{ re = 3.; im = 4. } |]
in
(* b = [5+6i; 7+8i] *)
let b =
  create complex32 [| 2 |]
    [| Complex.{ re = 5.; im = 6. }; Complex.{ re = 7.; im = 8. } |]
in
let result = vdot a b in
to_array result
```

Expected (same as NumPy's `vdot`): real part `70.` and imaginary part `-8.`. The current code returns `-18.` and `68.` because it multiplies `a * b` directly instead of `conj(a) * b`.
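To double-check the expected numbers without Nx, here is a minimal sketch that computes both variants on plain OCaml arrays using only the stdlib `Complex` module. `vdot_ref` and `dot_noconj` are hypothetical names for illustration; this is not how `frontend.ml` implements `vdot`.

```ocaml
(* Conjugating dot product: sum over i of conj(a_i) * b_i, as NumPy's vdot does. *)
let vdot_ref (a : Complex.t array) (b : Complex.t array) : Complex.t =
  assert (Array.length a = Array.length b);
  let acc = ref Complex.zero in
  Array.iteri
    (fun i ai -> acc := Complex.add !acc (Complex.mul (Complex.conj ai) b.(i)))
    a;
  !acc

(* Buggy variant: sum over i of a_i * b_i, with no conjugation. *)
let dot_noconj (a : Complex.t array) (b : Complex.t array) : Complex.t =
  assert (Array.length a = Array.length b);
  let acc = ref Complex.zero in
  Array.iteri (fun i ai -> acc := Complex.add !acc (Complex.mul ai b.(i))) a;
  !acc

(* Same inputs as the reproduction in this issue. *)
let a = [| Complex.{ re = 1.; im = 2. }; Complex.{ re = 3.; im = 4. } |]
let b = [| Complex.{ re = 5.; im = 6. }; Complex.{ re = 7.; im = 8. } |]

let () =
  let good = vdot_ref a b and bad = dot_noconj a b in
  Printf.printf "conj(a).b = %g%+gi\n" good.Complex.re good.Complex.im;
  Printf.printf "a.b       = %g%+gi\n" bad.Complex.re bad.Complex.im
```

It prints `70-8i` for the conjugating version and `-18+68i` for the direct product, matching the expected and current values reported for `Nx.vdot`.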
Your task
- Update the complex-number branch of `vdot` in `nx/lib/core/frontend.ml` so it conjugates the first tensor before multiplying.
- Add a regression test in `nx/test/test_nx_linalg.ml` that covers complex inputs (compare against `Complex.{ re; im }`).
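One way the regression test could compare its result against `Complex.{ re; im }` is a small tolerance check, since `complex32` stores single-precision components. This is a stdlib-only sketch: `complex_close` is a hypothetical helper, and `actual` is a stand-in for the value read back from `Nx.vdot` in the real test. If `test_nx_linalg.ml` uses Alcotest (an assumption; check the existing file), the same predicate could be wrapped with `Alcotest.testable`.

```ocaml
(* Hypothetical helper: true when both components are within eps. *)
let complex_close ?(eps = 1e-5) (expected : Complex.t) (actual : Complex.t) =
  Float.abs (expected.Complex.re -. actual.Complex.re) <= eps
  && Float.abs (expected.Complex.im -. actual.Complex.im) <= eps

(* Expected value from this issue: conj(a) . b = 70 - 8i. *)
let expected = Complex.{ re = 70.; im = -8. }

let () =
  (* Stand-in for the scalar read back from Nx.vdot in the real test. *)
  let actual = Complex.{ re = 70.; im = -8. } in
  assert (complex_close expected actual);
  (* The pre-fix result must be rejected, so the test fails before the change. *)
  assert (not (complex_close expected Complex.{ re = -18.; im = 68. }))
```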
Tips
- Helper functions live nearby: see `real`, `imag`, and `complex` around line 5930.
- Conjugating a complex number means keeping the real part and negating the imaginary part.
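The conjugation rule can be checked on stdlib values: rebuilding a number from its real part and negated imaginary part gives the same result as `Complex.conj`. Mapping it over an array is only a rough analogue of conjugating a whole tensor; how `real`, `imag`, and `complex` compose in `frontend.ml` is not shown here, and `conj_parts` and `conj_array` are hypothetical names.

```ocaml
(* Keep the real part, negate the imaginary part. *)
let conj_parts (z : Complex.t) : Complex.t =
  { Complex.re = z.Complex.re; im = -. z.Complex.im }

(* Element-wise conjugation over an array of complex values. *)
let conj_array (xs : Complex.t array) : Complex.t array = Array.map conj_parts xs

let xs = [| Complex.{ re = 1.; im = 2. }; Complex.{ re = 3.; im = 4. } |]

let () =
  (* Prints 1-2i and 3-4i, one per line. *)
  Array.iter
    (fun y -> Printf.printf "%g%+gi\n" y.Complex.re y.Complex.im)
    (conj_array xs)
```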
Done when
- The new test fails before your change and passes after it.
- `dune runtest nx/test:test_nx_linalg` succeeds.