diff --git a/src/mssm/models.py b/src/mssm/models.py index 4718738..78a9810 100644 --- a/src/mssm/models.py +++ b/src/mssm/models.py @@ -1,13 +1,11 @@ import numpy as np import scipy as scp -import multiprocessing as mp -from itertools import repeat import copy from collections.abc import Callable from .src.python.formula import Formula,PFormula,PTerm,build_sparse_matrix_from_formula,VarType,lhs,ConstType,Constraint from .src.python.exp_fam import Link,Logit,Family,Binomial,Gaussian from .src.python.sem import anneal_temps_zero,const_temps,compute_log_probs,pre_ll_sms_gamm,se_step_sms_gamm,decode_local,se_step_sms_dc_gamm,pre_ll_sms_IR_gamm,init_states_IR -from .src.python.gamm_solvers import solve_gamm_sparse +from .src.python.gamm_solvers import solve_gamm_sparse,mp,repeat from .src.python.terms import TermType,GammTerm,i,f,fs,irf,l,li,ri,rs from .src.python.penalties import PenType @@ -216,7 +214,7 @@ def print_smooth_terms(self,pen_cutoff=0.2): ##################################### Fitting ##################################### - def fit(self,maxiter=50,conv_tol=1e-7,extend_lambda=True,control_lambda=True,restart=False,progress_bar=True): + def fit(self,maxiter=50,conv_tol=1e-7,extend_lambda=True,control_lambda=True,restart=False,progress_bar=True,n_cores=10): """ Fit the specified model. @@ -279,7 +277,7 @@ def fit(self,maxiter=50,conv_tol=1e-7,extend_lambda=True,control_lambda=True,res model_mat,penalties,self.formula.n_coef, self.family,maxiter,"svd", conv_tol,extend_lambda,control_lambda, - progress_bar) + progress_bar,n_cores) self.__coef = coef self.__scale = scale # ToDo: scale name is used in another context for more general mssm.. 
@@ -663,7 +661,8 @@ def fit(self,burn_in=100,maxiter_inner=30,m_avg=15,conv_tol=1e-7,extend_lambda=T coef,eta,wres,scale,LVI,edf,term_edf,penalty = solve_gamm_sparse(init_mu_flat,state_y, model_mat,penalties[j],self.formula.n_coef, self.family,maxiter_inner,"svd", - conv_tol,extend_lambda,control_lambda,False) + conv_tol,extend_lambda,control_lambda, + False,self.cpus) @@ -976,7 +975,8 @@ def fit(self,maxiter_outer=100,maxiter_inner=30,conv_tol=1e-6,extend_lambda=True coef,eta,wres,scale,LVI,edf,term_edf,penalty = solve_gamm_sparse(init_mu_flat,y_flat[NOT_NA_flat], model_mat_full,penalties,self.formula.n_coef, self.family,maxiter_inner,"svd", - conv_tol,extend_lambda,control_lambda,False) + conv_tol,extend_lambda,control_lambda, + False,self.cpus) # For state proposals we can utilize a temparature schedule. See sMsGamm.fit(). if schedule == "anneal": @@ -1055,7 +1055,8 @@ def fit(self,maxiter_outer=100,maxiter_inner=30,conv_tol=1e-6,extend_lambda=True coef,eta,wres,scale,LVI,edf,term_edf,penalty = solve_gamm_sparse(init_mu_flat,y_flat[NOT_NA_flat], model_mat_full,penalties,self.formula.n_coef, self.family,maxiter_inner,"svd", - conv_tol,extend_lambda,control_lambda,False) + conv_tol,extend_lambda,control_lambda, + False,self.cpus) # Next update all sojourn time distribution parameters diff --git a/src/mssm/src/cpp/cpp_solvers.cpp b/src/mssm/src/cpp/cpp_solvers.cpp index e301c74..c7609de 100644 --- a/src/mssm/src/cpp/cpp_solvers.cpp +++ b/src/mssm/src/cpp/cpp_solvers.cpp @@ -191,13 +191,12 @@ std::tuple,Eigen::VectorXi,Eigen::VectorXd,int> solv return std::make_tuple(solver.matrixL(),P.indices(),std::move(coef),0); } -Eigen::SparseMatrix solve_tr(Eigen::SparseMatrix L,Eigen::SparseMatrix P,Eigen::SparseMatrix D){ - // Solves L * B = P * D for B. - // The sum of squares over B corresponds to the trace necessary for the edf/scale parameter computation (Wood & Fasiolo, 2017). 
- // This computation needs to be repeated for all smooth terms (i.e., all D) but can be completely parallelized. This will however - // only be worth it if the individual solves take time - big models. - Eigen::SparseMatrix B = P * D; - L.triangularView().solveInPlace(B); +Eigen::SparseMatrix solve_tr(Eigen::SparseMatrix A,Eigen::SparseMatrix B){ + // Solves A*B=C, where A is lower triangular. This can be utilized to obtain B = inv(A), when C is + // the identity. Importantly, when A is a n*n matrix then C can also be specified as a n*m block of + // the identity. In that case, inv(A) can be obtained in parallel. + + A.triangularView().solveInPlace(B); return B; } @@ -209,5 +208,5 @@ PYBIND11_MODULE(cpp_solvers, m) { m.def("solve_am", &solve_am, "Solve additive model, return coefficient vector and inverse"); m.def("solve_L", &solve_L, "Solve cholesky of XX+S"); m.def("solve_coef", &solve_coef, "Solve additive model coefficients"); - m.def("solve_tr",&solve_tr,"Solve for trace matrix required for lambda update."); + m.def("solve_tr",&solve_tr,"Solve A*B = C, where A is lower triangular."); } \ No newline at end of file diff --git a/src/mssm/src/python/gamm_solvers.py b/src/mssm/src/python/gamm_solvers.py index d238460..9e64967 100644 --- a/src/mssm/src/python/gamm_solvers.py +++ b/src/mssm/src/python/gamm_solvers.py @@ -1,5 +1,7 @@ import numpy as np import scipy as scp +import multiprocessing as mp +from itertools import repeat import warnings from .exp_fam import Family,Gaussian from .penalties import PenType,id_dist_pen,translate_sparse @@ -21,8 +23,8 @@ def cpp_solve_coef(y,X,S): def cpp_solve_L(X,S): return cpp_solvers.solve_L(X,S) -def cpp_solve_tr(L,P,D): - return cpp_solvers.solve_tr(L,P,D) +def cpp_solve_tr(A,B): + return cpp_solvers.solve_tr(A,B) def step_fellner_schall_sparse(gInv,emb_SJ,Bps,cCoef,cLam,scale,verbose=False): # Compute a generalized Fellner Schall update step for a lambda term.
This update rule is @@ -202,6 +204,63 @@ def apply_eigen_perm(Pr,InvCholXXSP): InvCholXXS = InvCholXXSP @ Perm return InvCholXXS +def compute_B_mp(L,PD): + B = cpp_solve_tr(L,PD) + return B.power(2).sum() + +def compute_B(L,P,lTerm,n_c=10): + # Solves L @ B = P @ D for B, parallelizing over column + # blocks of D if int(D.shape[1]/1000) > 1 + + # D is extremely sparse and P only shuffles rows, so we + # can take only the columns which we know contain non-zero elements + D_start = lTerm.start_index + D_len = lTerm.rep_sj * lTerm.S_J.shape[1] + D_end = lTerm.start_index + D_len + + D_r = int(D_len/1000) + if D_r > 1 and n_c > 1: + # Parallelize + n_c = min(D_r,n_c) + split = np.array_split(range(D_start,D_end),n_c) + PD = P @ lTerm.D_J_emb + PDs = [PD[:,split[i]] for i in range(n_c)] + + with mp.Pool(processes=n_c) as pool: + args = zip(repeat(L),PDs) + + pow_sums = pool.starmap(compute_B_mp,args) + return sum(pow_sums) + + B = cpp_solve_tr(L,P @ lTerm.D_J_emb[:,D_start:D_end]) + return B.power(2).sum() + +def compute_Linv(L,n_c=10): + # Solves L @ inv(L) = I for inv(L), parallelizing over column + # blocks of I if int(I.shape[1]/10000) > 1 + + n_col = L.shape[1] + r = int(n_col/10000) + T = scp.sparse.eye(n_col,format='csc') + if r > 1 and n_c > 1: + # Parallelize over column blocks of I + # Can speed up computations considerably and is feasible memory-wise + # since L itself is super sparse.
+ + n_c = min(r,n_c) + split = np.array_split(range(n_col),n_c) + LBs = [T[:,split[i]] for i in range(n_c)] + + with mp.Pool(processes=n_c) as pool: + args = zip(repeat(L),LBs) + + LBinvs = pool.starmap(cpp_solve_tr,args) + + return scp.sparse.hstack(LBinvs) + + return cpp_solve_tr(L,T) + + def calculate_edf(InvCholXXS,penalties,colsX): total_edf = colsX Bs = [] @@ -286,13 +345,16 @@ def update_scale_edf(y,z,eta,Wr,rowsX,colsX,InvCholXXSP,Pr,family,penalties): return wres,InvCholXXS,total_edf,term_edfs,Bs,scale -def update_coef_and_scale(y,yb,z,Wr,rowsX,colsX,X,Xb,family,S_emb,penalties): +def update_coef_and_scale(y,yb,z,Wr,rowsX,colsX,X,Xb,family,S_emb,penalties,n_c): # Solves the additive model for a given set of weights and penalty - InvCholXXSP, Pr, coef, code = cpp_solve_am(yb,Xb,S_emb) + LP, Pr, coef, code = cpp_solve_coef(yb,Xb,S_emb) if code != 0: raise ArithmeticError(f"Solving for coef failed with code {code}. Model is likely unidentifiable.") + # Solve for inverse of Chol factor of XX+S + InvCholXXSP = compute_Linv(LP,n_c) + # Update mu & eta eta = (X @ coef).reshape(-1,1) mu = eta @@ -306,10 +368,12 @@ def update_coef_and_scale(y,yb,z,Wr,rowsX,colsX,X,Xb,family,S_emb,penalties): def solve_gamm_sparse(mu_init,y,X,penalties,col_S,family:Family, maxiter=10,pinv="svd",conv_tol=1e-7, - extend_lambda=True,control_lambda=True,progress_bar=False): + extend_lambda=True,control_lambda=True, + progress_bar=False,n_c=10): # Estimates a penalized Generalized additive mixed model, following the steps outlined in Wood (2017) # "Generalized Additive Models for Gigadata" + n_c = min(mp.cpu_count(),n_c) rowsX,colsX = X.shape coef = None n_coef = None @@ -348,7 +412,7 @@ def solve_gamm_sparse(mu_init,y,X,penalties,col_S,family:Family, total_edf,\ term_edfs,\ Bs,scale,wres = update_coef_and_scale(y,yb,z,Wr,rowsX,colsX, - X,Xb,family,S_emb,penalties) + X,Xb,family,S_emb,penalties,n_c) # Deviance under these starting coefficients # As well as penalized deviance @@ -467,7 
+531,7 @@ def solve_gamm_sparse(mu_init,y,X,penalties,col_S,family:Family, total_edf,\ term_edfs,\ Bs,scale,wres = update_coef_and_scale(y,yb,z,Wr,rowsX,colsX, - X,Xb,family,S_emb,penalties) + X,Xb,family,S_emb,penalties,n_c) # Compute gradient of REML with respect to lambda # to check if step size needs to be reduced. @@ -524,7 +588,7 @@ def solve_gamm_sparse(mu_init,y,X,penalties,col_S,family:Family, total_edf,\ term_edfs,\ Bs,scale,wres = update_coef_and_scale(y,yb,z,Wr,rowsX,colsX, - X,Xb,family,S_emb,penalties) + X,Xb,family,S_emb,penalties,n_c) # Final penalty if len(penalties) > 0: diff --git a/tutorials/1) GAMMs.ipynb b/tutorials/1) GAMMs.ipynb index aa4af21..afe6747 100644 --- a/tutorials/1) GAMMs.ipynb +++ b/tutorials/1) GAMMs.ipynb @@ -807,7 +807,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Converged!: 66%|██████▌ | 33/50 [00:01<00:00, 24.74it/s] " + "Converged!: 66%|██████▌ | 33/50 [00:01<00:00, 25.78it/s] " ] }, { @@ -935,7 +935,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Converged!: 68%|██████▊ | 34/50 [00:09<00:04, 3.40it/s] " + "Converged!: 68%|██████▊ | 34/50 [00:09<00:04, 3.77it/s] " ] }, { @@ -1105,7 +1105,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Converged!: 54%|█████▍ | 27/50 [00:01<00:01, 19.41it/s] " + "Converged!: 54%|█████▍ | 27/50 [00:01<00:01, 19.65it/s] " ] }, { @@ -1277,7 +1277,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Converged!: 28%|██▊ | 14/50 [00:05<00:13, 2.58it/s] " + "Converged!: 28%|██▊ | 14/50 [00:05<00:12, 2.79it/s] " ] }, { @@ -1426,7 +1426,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Converged!: 58%|█████▊ | 29/50 [00:11<00:08, 2.60it/s] " + "Converged!: 58%|█████▊ | 29/50 [00:10<00:07, 2.83it/s] " ] }, { @@ -1597,7 +1597,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Converged!: 48%|████▊ | 24/50 [00:27<00:29, 1.14s/it] " + "Converged!: 48%|████▊ | 24/50 [00:27<00:29, 1.13s/it] " ] }, { @@ -1816,7 +1816,7 @@ "name": 
"stderr", "output_type": "stream", "text": [ - "Converged!: 32%|███▏ | 16/50 [00:19<00:42, 1.24s/it] " + "Converged!: 32%|███▏ | 16/50 [00:19<00:41, 1.23s/it] " ] }, {