Skip to content

Commit

Permalink
Port Eigenvalue Decompositions to XLA's FFI
Browse files Browse the repository at this point in the history
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.

PiperOrigin-RevId: 659492696
  • Loading branch information
pparuzel authored and jax authors committed Aug 5, 2024
1 parent 9b35b76 commit b2a469b
Show file tree
Hide file tree
Showing 8 changed files with 676 additions and 23 deletions.
2 changes: 2 additions & 0 deletions jaxlib/cpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ cc_library(
"@xla//xla/service:custom_call_status",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:dynamic_annotations",
"@com_google_absl//absl/types:span",
],
)

Expand All @@ -64,6 +65,7 @@ pybind_extension(
pytype_srcs = [
"_lapack/__init__.pyi",
"_lapack/svd.pyi",
"_lapack/eig.pyi",
],
deps = [
":lapack_kernels",
Expand Down
5 changes: 5 additions & 0 deletions jaxlib/cpu/_lapack/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from . import eig as eig
from . import svd as svd


Expand Down Expand Up @@ -53,6 +54,8 @@ def cgesdd_work_size_ffi(m: int, n: int, mode: svd.ComputationMode) -> int: ...
def dgesdd_work_size_ffi(m: int, n: int, mode: svd.ComputationMode) -> int: ...
def gesdd_iwork_size_ffi(m: int, n: int) -> int: ...
def gesdd_rwork_size_ffi(m: int, n: int, mode: svd.ComputationMode) -> int: ...
# Workspace-size queries for the FFI heevd kernels. The native module binds
# these with the computation mode fixed to
# eig.ComputationMode.kComputeEigenvectors, so only `n` is taken here.
def heevd_rwork_size_ffi(n: int) -> int: ...
def heevd_work_size_ffi(n: int) -> int: ...
def lapack_cgeqrf_workspace_ffi(m: int, n: int) -> int: ...
def lapack_cungqr_workspace_ffi(m: int, n: int, k: int) -> int: ...
def lapack_dgeqrf_workspace_ffi(m: int, n: int) -> int: ...
Expand All @@ -62,4 +65,6 @@ def lapack_sorgqr_workspace_ffi(m: int, n: int, k: int) -> int: ...
def lapack_zgeqrf_workspace_ffi(m: int, n: int) -> int: ...
def lapack_zungqr_workspace_ffi(m: int, n: int, k: int) -> int: ...
def sgesdd_work_size_ffi(m: int, n: int, mode: svd.ComputationMode) -> int: ...
# Workspace-size queries for the FFI syevd kernels. The native module binds
# these with the computation mode fixed to
# eig.ComputationMode.kComputeEigenvectors, so only `n` is taken here.
def syevd_iwork_size_ffi(n: int) -> int: ...
def syevd_work_size_ffi(n: int) -> int: ...
def zgesdd_work_size_ffi(m: int, n: int, mode: svd.ComputationMode) -> int: ...
21 changes: 21 additions & 0 deletions jaxlib/cpu/_lapack/eig.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import enum
from typing import ClassVar


# Type stub for the `ComputationMode` enum that the native `_lapack` module
# registers on its `eig` submodule. It has exactly two members, mirroring the
# values bound on the C++ side.
class ComputationMode(enum.Enum):
# Bound to eig::ComputationMode::kComputeEigenvectors in the native module.
kComputeEigenvectors: ClassVar[ComputationMode] = ...
# Bound to eig::ComputationMode::kNoEigenvectors in the native module.
kNoEigenvectors: ClassVar[ComputationMode] = ...
8 changes: 8 additions & 0 deletions jaxlib/cpu/cpu_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,14 @@ JAX_CPU_REGISTER_HANDLER(lapack_sgesdd_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_dgesdd_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_cgesdd_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_zgesdd_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_ssyevd_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_dsyevd_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_cheevd_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_zheevd_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_sgeev_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_dgeev_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_cgeev_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_zgeev_ffi);

#undef JAX_CPU_REGISTER_HANDLER

Expand Down
44 changes: 43 additions & 1 deletion jaxlib/cpu/lapack.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>
#include <complex>
#include <cstdint>

#include "nanobind/nanobind.h"
#include "jaxlib/cpu/lapack_kernels.h"
Expand Down Expand Up @@ -52,6 +52,14 @@ lapack_int GesddGetRealWorkspaceSize(lapack_int m, lapack_int n,
return svd::GetRealWorkspaceSize(m, n, mode);
}

// Due to enforced kComputeEigenvectors, this assumes a larger workspace size.
// Could be improved to more accurately estimate the expected size based on the
// eig::ComputationMode value.
template <lapack_int (&f)(int64_t, eig::ComputationMode)>
inline constexpr auto BoundWithEigvecs = +[](lapack_int n) {
return f(n, eig::ComputationMode::kComputeEigenvectors);
};

void GetLapackKernelsFromScipy() {
static bool initialized = false; // Protected by GIL
if (initialized) return;
Expand Down Expand Up @@ -128,11 +136,25 @@ void GetLapackKernelsFromScipy() {
AssignKernelFn<RealSyevd<double>>(lapack_ptr("dsyevd"));
AssignKernelFn<ComplexHeevd<std::complex<float>>>(lapack_ptr("cheevd"));
AssignKernelFn<ComplexHeevd<std::complex<double>>>(lapack_ptr("zheevd"));
AssignKernelFn<EigenvalueDecompositionSymmetric<DataType::F32>>(
lapack_ptr("ssyevd"));
AssignKernelFn<EigenvalueDecompositionSymmetric<DataType::F64>>(
lapack_ptr("dsyevd"));
AssignKernelFn<EigenvalueDecompositionHermitian<DataType::C64>>(
lapack_ptr("cheevd"));
AssignKernelFn<EigenvalueDecompositionHermitian<DataType::C128>>(
lapack_ptr("zheevd"));

AssignKernelFn<RealGeev<float>>(lapack_ptr("sgeev"));
AssignKernelFn<RealGeev<double>>(lapack_ptr("dgeev"));
AssignKernelFn<ComplexGeev<std::complex<float>>>(lapack_ptr("cgeev"));
AssignKernelFn<ComplexGeev<std::complex<double>>>(lapack_ptr("zgeev"));
AssignKernelFn<EigenvalueDecomposition<DataType::F32>>(lapack_ptr("sgeev"));
AssignKernelFn<EigenvalueDecomposition<DataType::F64>>(lapack_ptr("dgeev"));
AssignKernelFn<EigenvalueDecompositionComplex<DataType::C64>>(
lapack_ptr("cgeev"));
AssignKernelFn<EigenvalueDecompositionComplex<DataType::C128>>(
lapack_ptr("zgeev"));

AssignKernelFn<RealGees<float>>(lapack_ptr("sgees"));
AssignKernelFn<RealGees<double>>(lapack_ptr("dgees"));
Expand Down Expand Up @@ -246,6 +268,14 @@ nb::dict Registrations() {
dict["lapack_dgesdd_ffi"] = EncapsulateFunction(lapack_dgesdd_ffi);
dict["lapack_cgesdd_ffi"] = EncapsulateFunction(lapack_cgesdd_ffi);
dict["lapack_zgesdd_ffi"] = EncapsulateFunction(lapack_zgesdd_ffi);
dict["lapack_ssyevd_ffi"] = EncapsulateFunction(lapack_ssyevd_ffi);
dict["lapack_dsyevd_ffi"] = EncapsulateFunction(lapack_dsyevd_ffi);
dict["lapack_cheevd_ffi"] = EncapsulateFunction(lapack_cheevd_ffi);
dict["lapack_zheevd_ffi"] = EncapsulateFunction(lapack_zheevd_ffi);
dict["lapack_sgeev_ffi"] = EncapsulateFunction(lapack_sgeev_ffi);
dict["lapack_dgeev_ffi"] = EncapsulateFunction(lapack_dgeev_ffi);
dict["lapack_cgeev_ffi"] = EncapsulateFunction(lapack_cgeev_ffi);
dict["lapack_zgeev_ffi"] = EncapsulateFunction(lapack_zgeev_ffi);

return dict;
}
Expand All @@ -256,12 +286,16 @@ NB_MODULE(_lapack, m) {
m.def("registrations", &Registrations);
// Submodules
auto svd = m.def_submodule("svd");
auto eig = m.def_submodule("eig");
// Enums
nb::enum_<svd::ComputationMode>(svd, "ComputationMode")
// kComputeVtOverwriteXPartialU is not implemented
.value("kComputeFullUVt", svd::ComputationMode::kComputeFullUVt)
.value("kComputeMinUVt", svd::ComputationMode::kComputeMinUVt)
.value("kNoComputeUVt", svd::ComputationMode::kNoComputeUVt);
nb::enum_<eig::ComputationMode>(eig, "ComputationMode")
.value("kComputeEigenvectors", eig::ComputationMode::kComputeEigenvectors)
.value("kNoEigenvectors", eig::ComputationMode::kNoEigenvectors);

// Old-style LAPACK Workspace Size Queries
m.def("lapack_sgeqrf_workspace", &Geqrf<float>::Workspace, nb::arg("m"),
Expand Down Expand Up @@ -353,6 +387,14 @@ NB_MODULE(_lapack, m) {
nb::arg("m"), nb::arg("n"), nb::arg("mode"));
m.def("zgesdd_work_size_ffi", &svd::SVDType<DataType::C128>::GetWorkspaceSize,
nb::arg("m"), nb::arg("n"), nb::arg("mode"));
m.def("syevd_work_size_ffi", BoundWithEigvecs<eig::GetWorkspaceSize>,
nb::arg("n"));
m.def("syevd_iwork_size_ffi", BoundWithEigvecs<eig::GetIntWorkspaceSize>,
nb::arg("n"));
m.def("heevd_work_size_ffi", BoundWithEigvecs<eig::GetComplexWorkspaceSize>,
nb::arg("n"));
m.def("heevd_rwork_size_ffi", BoundWithEigvecs<eig::GetRealWorkspaceSize>,
nb::arg("n"));
}

} // namespace
Expand Down
Loading

0 comments on commit b2a469b

Please sign in to comment.