
Fix complex numbers in Nx.vdot #76

@tmattio

Description


Why it matters

The inner product of a vector with itself should be a non-negative real number (its squared norm). For complex tensors, this holds only if one of the vectors is conjugated before the elementwise products are summed. NumPy's vdot conjugates the first argument. This is the standard Hermitian inner product, and it keeps norms and inner products well-behaved. Our Nx.vdot (in nx/lib/core/frontend.ml) skips the conjugation step, so it returns wrong results for complex inputs.
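For reference, here is the Hermitian inner product with the conjugate on the first argument, matching NumPy's vdot. Conjugation is what makes the self-product equal to a sum of squared magnitudes:

$$
\langle a, b \rangle = \sum_i \overline{a_i}\, b_i,
\qquad
\langle a, a \rangle = \sum_i \overline{a_i}\, a_i = \sum_i |a_i|^2 \ge 0 .
$$

Without the conjugate, the single-element vector a = (i) already gives a · a = i · i = -1, which is negative.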

How to see the bug

let open Nx in
(* Two complex32 vectors of length 2 *)
let a =
  create complex32 [| 2 |]
    [| Complex.{ re = 1.; im = 2. }; Complex.{ re = 3.; im = 4. } |]
in
let b =
  create complex32 [| 2 |]
    [| Complex.{ re = 5.; im = 6. }; Complex.{ re = 7.; im = 8. } |]
in
(* vdot should compute sum(conj(a) * b) *)
let result = vdot a b in
to_array result

Expected (matching NumPy's vdot): 70 - 8i, i.e. Complex.{ re = 70.; im = -8. }, since (1 - 2i)(5 + 6i) + (3 - 4i)(7 + 8i) = (17 - 4i) + (53 - 4i) = 70 - 8i. The current code returns -18 + 68i because it multiplies a * b directly instead of conj(a) * b.
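A stdlib-only reference that reproduces both numbers may help when writing the regression test. This is a sketch over plain Complex.t arrays, not the Nx implementation. The names vdot_ref and dot_unconj are hypothetical and exist only for illustration.

```ocaml
(* Hypothetical reference helpers over plain Complex.t arrays (stdlib only, no Nx).
   [vdot_ref] conjugates its first argument, as NumPy's vdot does;
   [dot_unconj] mimics the current, unconjugated behaviour. *)
let vdot_ref (a : Complex.t array) (b : Complex.t array) : Complex.t =
  if Array.length a <> Array.length b then invalid_arg "vdot_ref: length mismatch";
  let acc = ref Complex.zero in
  Array.iteri
    (fun i ai -> acc := Complex.add !acc (Complex.mul (Complex.conj ai) b.(i)))
    a;
  !acc

let dot_unconj (a : Complex.t array) (b : Complex.t array) : Complex.t =
  if Array.length a <> Array.length b then invalid_arg "dot_unconj: length mismatch";
  let acc = ref Complex.zero in
  Array.iteri (fun i ai -> acc := Complex.add !acc (Complex.mul ai b.(i))) a;
  !acc

let () =
  (* Same inputs as the reproduction above *)
  let a = [| Complex.{ re = 1.; im = 2. }; Complex.{ re = 3.; im = 4. } |] in
  let b = [| Complex.{ re = 5.; im = 6. }; Complex.{ re = 7.; im = 8. } |] in
  let show (z : Complex.t) = Printf.sprintf "%g%+gi" z.Complex.re z.Complex.im in
  Printf.printf "conj(a) . b = %s\n" (show (vdot_ref a b));
  Printf.printf "a . b       = %s\n" (show (dot_unconj a b))
```

Running it should print 70-8i for the conjugated product and -18+68i for the unconjugated one. These are the expected value and the buggy value from this issue.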

Your task

  • Update the complex-number branch of vdot in nx/lib/core/frontend.ml so it conjugates the first tensor before multiplying.
  • Add a regression test in nx/test/test_nx_linalg.ml that covers complex inputs (compare against Complex.{ re; im }).

Tips

  • Helper functions live nearby: see real, imag, and complex around line 5930.
  • Conjugating a complex number means keeping the real part and negating the imaginary part.
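The conjugation rule from the tip above can be written directly on the OCaml record and checked against the stdlib's Complex.conj. The sketch also includes a tolerance-based comparison that a regression test could use. Both conj_manual and complex_close are hypothetical names. The tolerance assumes complex32 results are single precision, so exact float equality may be too strict for other inputs.

```ocaml
(* Conjugation by hand: keep the real part, negate the imaginary part.
   This matches the stdlib's Complex.conj. *)
let conj_manual (z : Complex.t) : Complex.t =
  { Complex.re = z.Complex.re; im = -. z.Complex.im }

(* Hypothetical test helper: compare two complex values component-wise
   within an absolute tolerance [eps]. *)
let complex_close ?(eps = 1e-4) (x : Complex.t) (y : Complex.t) : bool =
  Float.abs (x.Complex.re -. y.Complex.re) <= eps
  && Float.abs (x.Complex.im -. y.Complex.im) <= eps

let () =
  let z = Complex.{ re = 3.; im = 4. } in
  let c = conj_manual z in
  Printf.printf "conj(3+4i) = %g%+gi\n" c.Complex.re c.Complex.im;
  (* z * conj z equals |z|^2, a real number (25 here) *)
  let p = Complex.mul z c in
  Printf.printf "z * conj z = %g%+gi\n" p.Complex.re p.Complex.im;
  Printf.printf "close to 25: %b\n" (complex_close p Complex.{ re = 25.; im = 0. })
```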

Done when

  • The new test fails before your change and passes after it.
  • dune runtest nx/test:test_nx_linalg succeeds.

Metadata

  • Assignees: No one assigned
  • Labels: good first issue (Good for newcomers), outreachy (Issues targeted at Outreachy applicants)
  • Type: No type
  • Projects: No projects
  • Milestone: No milestone
  • Relationships: None yet
  • Development: No branches or pull requests
