diff --git a/src/mssm/src/cpp/cpp_solvers.cpp b/src/mssm/src/cpp/cpp_solvers.cpp index 234566a..6002117 100644 --- a/src/mssm/src/cpp/cpp_solvers.cpp +++ b/src/mssm/src/cpp/cpp_solvers.cpp @@ -107,6 +107,44 @@ std::tuple,Eigen::SparseMatrix,Eigen::Vector } +std::tuple, int, int> solve_pqr(int Arows, int Acols, int Annz, + py::array_t Adata, + py::array_t Aidptr, + py::array_t Aindices){ + + Eigen::Map> A(Arows,Acols,Annz, + (Eigen::SparseMatrix::StorageIndex*) Aidptr.data(), + (Eigen::SparseMatrix::StorageIndex*) Aindices.data(), + (Eigen::SparseMatrix::Scalar*) Adata.data()); + + // Computed column-pivoted QR factorization of A and solve A @ B = I for B (inverse of A) + Eigen::SparseQR,Eigen::AMDOrdering> solver; + solver.compute(A); + + // Also setup identity target for inverse of A + Eigen::SparseMatrix id(Acols,Acols); + id.setIdentity(); + + if(solver.info()!=Eigen::Success) + { + + return std::make_tuple(std::move(id),0,1); + } + + // see: https://eigen.tuxfamily.org/dox/classEigen_1_1SparseQR.html + Eigen::SparseMatrix invA(Acols,Acols); + invA = solver.solve(id); + + if(solver.info()!=Eigen::Success) + { + + return std::make_tuple(std::move(id),0,1); + } + + return std::make_tuple(std::move(invA),solver.rank(),0); + +} + std::tuple,Eigen::VectorXi,Eigen::VectorXd,int> solve_am(Eigen::VectorXd y, int Xrows, int Xcols, int Xnnz, py::array_t Xdata, py::array_t Xidptr, @@ -420,6 +458,7 @@ PYBIND11_MODULE(cpp_solvers, m) { m.def("chol", &chol, "Compute cholesky factor L of A"); m.def("cholP", &cholP, "Compute cholesky factor L of A after applying a sparsity enhancing permutation to A"); m.def("pqr", &pqr, "Perform column pivoted QR decomposition of A"); + m.def("solve_pqr", &solve_pqr, "Perform column pivoted QR decomposition of A, then solve for inverse of A"); m.def("solve_am", &solve_am, "Solve additive model, return coefficient vector and inverse"); m.def("solve_L", &solve_L, "Solve cholesky of XX+S"); m.def("solve_LXX", &solve_LXX, "Solve cholesky of XX+S, but with XX + S pre-computed."); diff --git a/src/mssm/src/python/gamm_solvers.py b/src/mssm/src/python/gamm_solvers.py index bef6cc0..54f2083 100644 --- a/src/mssm/src/python/gamm_solvers.py +++ b/src/mssm/src/python/gamm_solvers.py @@ -21,6 +21,9 @@ def cpp_cholP(A): def cpp_qr(A): return cpp_solvers.pqr(*map_csc_to_eigen(A)) +def cpp_solve_qr(A): + return cpp_solvers.solve_pqr(*map_csc_to_eigen(A)) + def cpp_solve_am(y,X,S): return cpp_solvers.solve_am(y,*map_csc_to_eigen(X),*map_csc_to_eigen(S)) @@ -1417,11 +1420,6 @@ def solve_gamm_sparse2(formula:Formula,penalties,col_S,family:Family, if not term_edfs is None: term_edfs = calculate_term_edf(penalties,term_edfs) - if InvCholXXS is None: - Lp, Pr, _ = cpp_cholP((XX+S_emb).tocsc()) - InvCholXXSP = compute_Linv(Lp,n_c) - InvCholXXS = apply_eigen_perm(Pr,InvCholXXSP) - clear_cache(CACHE_DIR,SHOULD_CACHE) return coef,eta,wres,scale,InvCholXXS,total_edf,term_edfs,penalty