@@ -19,17 +19,65 @@ using Reactant_jll: Reactant_jll
19
19
using .. TracedUtils: TracedUtils, get_mlir_data, set_mlir_data!
20
20
21
21
using LinearAlgebra
22
+ using LinearAlgebra. BLAS: @blasfunc
22
23
using Libdl: Libdl
23
24
24
25
function __init__ ()
25
26
if Reactant_jll. is_available ()
26
27
libblastrampoline_handle = Libdl. dlopen (LinearAlgebra. BLAS. libblas)
27
28
29
+ # notes:
30
+ # - straight names (e.g. `sgeqrf_`) are LAPACK (i.e. Fortran) symbols
31
+ # - they should finish WITH `_` as the 32-bit symbols are `sgeqrf_` and 64-bit symbols are `sgeqrf_64_`
32
+ # - LAPACKE names (e.g. `LAPACKE_sgeqrf_`) are LAPACKE (i.e. C) symbols and WITHOUT explicit workspace
33
+ # - they should finish WITHOUT `_` as the 32-bit symbols are `LAPACKE_sgeqrf` and 64-bit symbols are `LAPACKE_sgeqrf64_`
34
+ # - LAPACKE names with `_work` suffix (e.g. `LAPACKE_sgeqrf_work`) are LAPACKE symbols WITH explicit workspace
28
35
for (cname, enzymexla_name) in [
29
- (LinearAlgebra. BLAS. @blasfunc (sgetrf_), :enzymexla_lapack_sgetrf_ ),
30
- (LinearAlgebra. BLAS. @blasfunc (dgetrf_), :enzymexla_lapack_dgetrf_ ),
31
- (LinearAlgebra. BLAS. @blasfunc (cgetrf_), :enzymexla_lapack_cgetrf_ ),
32
- (LinearAlgebra. BLAS. @blasfunc (zgetrf_), :enzymexla_lapack_zgetrf_ ),
36
+ # getrf
37
+ (@blasfunc (sgetrf_), :enzymexla_lapack_sgetrf_ ),
38
+ (@blasfunc (dgetrf_), :enzymexla_lapack_dgetrf_ ),
39
+ (@blasfunc (cgetrf_), :enzymexla_lapack_cgetrf_ ),
40
+ (@blasfunc (zgetrf_), :enzymexla_lapack_zgetrf_ ),
41
+ # geqrf
42
+ (@blasfunc (sgeqrf_), :enzymexla_lapack_sgeqrf_ ),
43
+ (@blasfunc (dgeqrf_), :enzymexla_lapack_dgeqrf_ ),
44
+ (@blasfunc (cgeqrf_), :enzymexla_lapack_cgeqrf_ ),
45
+ (@blasfunc (zgeqrf_), :enzymexla_lapack_zgeqrf_ ),
46
+ (@blasfunc (LAPACKE_sgeqrf), :enzymexla_lapacke_sgeqrf_ ),
47
+ (@blasfunc (LAPACKE_dgeqrf), :enzymexla_lapacke_dgeqrf_ ),
48
+ (@blasfunc (LAPACKE_cgeqrf), :enzymexla_lapacke_cgeqrf_ ),
49
+ (@blasfunc (LAPACKE_zgeqrf), :enzymexla_lapacke_zgeqrf_ ),
50
+ (@blasfunc (LAPACKE_sgeqrf_work), :enzymexla_lapacke_sgeqrf_work_ ),
51
+ (@blasfunc (LAPACKE_dgeqrf_work), :enzymexla_lapacke_dgeqrf_work_ ),
52
+ (@blasfunc (LAPACKE_cgeqrf_work), :enzymexla_lapacke_cgeqrf_work_ ),
53
+ (@blasfunc (LAPACKE_zgeqrf_work), :enzymexla_lapacke_zgeqrf_work_ ),
54
+ # geqrt
55
+ (@blasfunc (sgeqrt_), :enzymexla_lapack_sgeqrt_ ),
56
+ (@blasfunc (dgeqrt_), :enzymexla_lapack_dgeqrt_ ),
57
+ (@blasfunc (cgeqrt_), :enzymexla_lapack_cgeqrt_ ),
58
+ (@blasfunc (zgeqrt_), :enzymexla_lapack_zgeqrt_ ),
59
+ (@blasfunc (LAPACKE_sgeqrt), :enzymexla_lapacke_sgeqrt_ ),
60
+ (@blasfunc (LAPACKE_dgeqrt), :enzymexla_lapacke_dgeqrt_ ),
61
+ (@blasfunc (LAPACKE_cgeqrt), :enzymexla_lapacke_cgeqrt_ ),
62
+ (@blasfunc (LAPACKE_zgeqrt), :enzymexla_lapacke_zgeqrt_ ),
63
+ (@blasfunc (LAPACKE_sgeqrt_work), :enzymexla_lapacke_sgeqrt_work_ ),
64
+ (@blasfunc (LAPACKE_dgeqrt_work), :enzymexla_lapacke_dgeqrt_work_ ),
65
+ (@blasfunc (LAPACKE_cgeqrt_work), :enzymexla_lapacke_cgeqrt_work_ ),
66
+ (@blasfunc (LAPACKE_zgeqrt_work), :enzymexla_lapacke_zgeqrt_work_ ),
67
+ # orgqr
68
+ (@blasfunc (sorgqr_), :enzymexla_lapack_sorgqr_ ),
69
+ (@blasfunc (dorgqr_), :enzymexla_lapack_dorgqr_ ),
70
+ (@blasfunc (LAPACKE_sorgqr), :enzymexla_lapacke_sorgqr_ ),
71
+ (@blasfunc (LAPACKE_dorgqr), :enzymexla_lapacke_dorgqr_ ),
72
+ (@blasfunc (LAPACKE_sorgqr_work), :enzymexla_lapacke_sorgqr_work_ ),
73
+ (@blasfunc (LAPACKE_dorgqr_work), :enzymexla_lapacke_dorgqr_work_ ),
74
+ # ungqr
75
+ (@blasfunc (cungqr_), :enzymexla_lapack_cungqr_ ),
76
+ (@blasfunc (zungqr_), :enzymexla_lapack_zungqr_ ),
77
+ (@blasfunc (LAPACKE_cungqr), :enzymexla_lapacke_cungqr_ ),
78
+ (@blasfunc (LAPACKE_zungqr), :enzymexla_lapacke_zungqr_ ),
79
+ (@blasfunc (LAPACKE_cungqr_work), :enzymexla_lapacke_cungqr_work_ ),
80
+ (@blasfunc (LAPACKE_zungqr_work), :enzymexla_lapacke_zungqr_work_ ),
33
81
]
34
82
sym = Libdl. dlsym (libblastrampoline_handle, cname)
35
83
@ccall MLIR. API. mlir_c. EnzymeJaXMapSymbol (
0 commit comments