Skip to content

Commit 72849b2

Browse files
committed
Register LAPACK(E) symbols
1 parent b126621 commit 72849b2

File tree

1 file changed

+52
-4
lines changed

1 file changed

+52
-4
lines changed

src/stdlibs/LinearAlgebra.jl

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,65 @@ using Reactant_jll: Reactant_jll
1919
using ..TracedUtils: TracedUtils, get_mlir_data, set_mlir_data!
2020

2121
using LinearAlgebra
22+
using LinearAlgebra.BLAS: @blasfunc
2223
using Libdl: Libdl
2324

2425
function __init__()
2526
if Reactant_jll.is_available()
2627
libblastrampoline_handle = Libdl.dlopen(LinearAlgebra.BLAS.libblas)
2728

29+
# notes:
30+
# - straight names (e.g. `sgeqrf_`) are LAPACK (i.e. Fortran) symbols
31+
# - they should finish WITH `_` as the 32-bit symbols are `sgeqrf_` and 64-bit symbols are `sgeqrf_64_`
32+
# - LAPACKE names (e.g. `LAPACKE_sgeqrf_`) are LAPACKE (i.e. C) symbols and WITHOUT explicit workspace
33+
# - they should finish WITHOUT `_` as the 32-bit symbols are `LAPACKE_sgeqrf` and 64-bit symbols are `LAPACKE_sgeqrf64_`
34+
# - LAPACKE names with `_work` suffix (e.g. `LAPACKE_sgeqrf_work`) are LAPACKE symbols WITH explicit workspace
2835
for (cname, enzymexla_name) in [
29-
(LinearAlgebra.BLAS.@blasfunc(sgetrf_), :enzymexla_lapack_sgetrf_),
30-
(LinearAlgebra.BLAS.@blasfunc(dgetrf_), :enzymexla_lapack_dgetrf_),
31-
(LinearAlgebra.BLAS.@blasfunc(cgetrf_), :enzymexla_lapack_cgetrf_),
32-
(LinearAlgebra.BLAS.@blasfunc(zgetrf_), :enzymexla_lapack_zgetrf_),
36+
# getrf
37+
(@blasfunc(sgetrf_), :enzymexla_lapack_sgetrf_),
38+
(@blasfunc(dgetrf_), :enzymexla_lapack_dgetrf_),
39+
(@blasfunc(cgetrf_), :enzymexla_lapack_cgetrf_),
40+
(@blasfunc(zgetrf_), :enzymexla_lapack_zgetrf_),
41+
# geqrf
42+
(@blasfunc(sgeqrf_), :enzymexla_lapack_sgeqrf_),
43+
(@blasfunc(dgeqrf_), :enzymexla_lapack_dgeqrf_),
44+
(@blasfunc(cgeqrf_), :enzymexla_lapack_cgeqrf_),
45+
(@blasfunc(zgeqrf_), :enzymexla_lapack_zgeqrf_),
46+
(@blasfunc(LAPACKE_sgeqrf), :enzymexla_lapacke_sgeqrf_),
47+
(@blasfunc(LAPACKE_dgeqrf), :enzymexla_lapacke_dgeqrf_),
48+
(@blasfunc(LAPACKE_cgeqrf), :enzymexla_lapacke_cgeqrf_),
49+
(@blasfunc(LAPACKE_zgeqrf), :enzymexla_lapacke_zgeqrf_),
50+
(@blasfunc(LAPACKE_sgeqrf_work), :enzymexla_lapacke_sgeqrf_work_),
51+
(@blasfunc(LAPACKE_dgeqrf_work), :enzymexla_lapacke_dgeqrf_work_),
52+
(@blasfunc(LAPACKE_cgeqrf_work), :enzymexla_lapacke_cgeqrf_work_),
53+
(@blasfunc(LAPACKE_zgeqrf_work), :enzymexla_lapacke_zgeqrf_work_),
54+
# geqrt
55+
(@blasfunc(sgeqrt_), :enzymexla_lapack_sgeqrt_),
56+
(@blasfunc(dgeqrt_), :enzymexla_lapack_dgeqrt_),
57+
(@blasfunc(cgeqrt_), :enzymexla_lapack_cgeqrt_),
58+
(@blasfunc(zgeqrt_), :enzymexla_lapack_zgeqrt_),
59+
(@blasfunc(LAPACKE_sgeqrt), :enzymexla_lapacke_sgeqrt_),
60+
(@blasfunc(LAPACKE_dgeqrt), :enzymexla_lapacke_dgeqrt_),
61+
(@blasfunc(LAPACKE_cgeqrt), :enzymexla_lapacke_cgeqrt_),
62+
(@blasfunc(LAPACKE_zgeqrt), :enzymexla_lapacke_zgeqrt_),
63+
(@blasfunc(LAPACKE_sgeqrt_work), :enzymexla_lapacke_sgeqrt_work_),
64+
(@blasfunc(LAPACKE_dgeqrt_work), :enzymexla_lapacke_dgeqrt_work_),
65+
(@blasfunc(LAPACKE_cgeqrt_work), :enzymexla_lapacke_cgeqrt_work_),
66+
(@blasfunc(LAPACKE_zgeqrt_work), :enzymexla_lapacke_zgeqrt_work_),
67+
# orgqr
68+
(@blasfunc(sorgqr_), :enzymexla_lapack_sorgqr_),
69+
(@blasfunc(dorgqr_), :enzymexla_lapack_dorgqr_),
70+
(@blasfunc(LAPACKE_sorgqr), :enzymexla_lapacke_sorgqr_),
71+
(@blasfunc(LAPACKE_dorgqr), :enzymexla_lapacke_dorgqr_),
72+
(@blasfunc(LAPACKE_sorgqr_work), :enzymexla_lapacke_sorgqr_work_),
73+
(@blasfunc(LAPACKE_dorgqr_work), :enzymexla_lapacke_dorgqr_work_),
74+
# ungqr
75+
(@blasfunc(cungqr_), :enzymexla_lapack_cungqr_),
76+
(@blasfunc(zungqr_), :enzymexla_lapack_zungqr_),
77+
(@blasfunc(LAPACKE_cungqr), :enzymexla_lapacke_cungqr_),
78+
(@blasfunc(LAPACKE_zungqr), :enzymexla_lapacke_zungqr_),
79+
(@blasfunc(LAPACKE_cungqr_work), :enzymexla_lapacke_cungqr_work_),
80+
(@blasfunc(LAPACKE_zungqr_work), :enzymexla_lapacke_zungqr_work_),
3381
]
3482
sym = Libdl.dlsym(libblastrampoline_handle, cname)
3583
@ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol(

0 commit comments

Comments
 (0)