Parallelized Inverse computation
JoKra1 committed Jan 23, 2024
1 parent b05f397 commit fc106b9
Showing 4 changed files with 95 additions and 31 deletions.
17 changes: 9 additions & 8 deletions src/mssm/models.py
@@ -1,13 +1,11 @@
import numpy as np
import scipy as scp
import multiprocessing as mp
from itertools import repeat
import copy
from collections.abc import Callable
from .src.python.formula import Formula,PFormula,PTerm,build_sparse_matrix_from_formula,VarType,lhs,ConstType,Constraint
from .src.python.exp_fam import Link,Logit,Family,Binomial,Gaussian
from .src.python.sem import anneal_temps_zero,const_temps,compute_log_probs,pre_ll_sms_gamm,se_step_sms_gamm,decode_local,se_step_sms_dc_gamm,pre_ll_sms_IR_gamm,init_states_IR
from .src.python.gamm_solvers import solve_gamm_sparse
from .src.python.gamm_solvers import solve_gamm_sparse,mp,repeat
from .src.python.terms import TermType,GammTerm,i,f,fs,irf,l,li,ri,rs
from .src.python.penalties import PenType

@@ -216,7 +214,7 @@ def print_smooth_terms(self,pen_cutoff=0.2):

##################################### Fitting #####################################

def fit(self,maxiter=50,conv_tol=1e-7,extend_lambda=True,control_lambda=True,restart=False,progress_bar=True):
def fit(self,maxiter=50,conv_tol=1e-7,extend_lambda=True,control_lambda=True,restart=False,progress_bar=True,n_cores=10):
"""
Fit the specified model.
@@ -279,7 +277,7 @@ def fit(self,maxiter=50,conv_tol=1e-7,extend_lambda=True,control_lambda=True,res
model_mat,penalties,self.formula.n_coef,
self.family,maxiter,"svd",
conv_tol,extend_lambda,control_lambda,
progress_bar)
progress_bar,n_cores)

self.__coef = coef
self.__scale = scale # ToDo: scale name is used in another context for more general mssm..
@@ -663,7 +661,8 @@ def fit(self,burn_in=100,maxiter_inner=30,m_avg=15,conv_tol=1e-7,extend_lambda=T
coef,eta,wres,scale,LVI,edf,term_edf,penalty = solve_gamm_sparse(init_mu_flat,state_y,
model_mat,penalties[j],self.formula.n_coef,
self.family,maxiter_inner,"svd",
conv_tol,extend_lambda,control_lambda,False)
conv_tol,extend_lambda,control_lambda,
False,self.cpus)



@@ -976,7 +975,8 @@ def fit(self,maxiter_outer=100,maxiter_inner=30,conv_tol=1e-6,extend_lambda=True
coef,eta,wres,scale,LVI,edf,term_edf,penalty = solve_gamm_sparse(init_mu_flat,y_flat[NOT_NA_flat],
model_mat_full,penalties,self.formula.n_coef,
self.family,maxiter_inner,"svd",
conv_tol,extend_lambda,control_lambda,False)
conv_tol,extend_lambda,control_lambda,
False,self.cpus)

# For state proposals we can utilize a temperature schedule. See sMsGamm.fit().
if schedule == "anneal":
@@ -1055,7 +1055,8 @@ def fit(self,maxiter_outer=100,maxiter_inner=30,conv_tol=1e-6,extend_lambda=True
coef,eta,wres,scale,LVI,edf,term_edf,penalty = solve_gamm_sparse(init_mu_flat,y_flat[NOT_NA_flat],
model_mat_full,penalties,self.formula.n_coef,
self.family,maxiter_inner,"svd",
conv_tol,extend_lambda,control_lambda,False)
conv_tol,extend_lambda,control_lambda,
False,self.cpus)

# Next update all sojourn time distribution parameters

15 changes: 7 additions & 8 deletions src/mssm/src/cpp/cpp_solvers.cpp
@@ -191,13 +191,12 @@ std::tuple<Eigen::SparseMatrix<double>,Eigen::VectorXi,Eigen::VectorXd,int> solv
return std::make_tuple(solver.matrixL(),P.indices(),std::move(coef),0);
}

Eigen::SparseMatrix<double> solve_tr(Eigen::SparseMatrix<double> L,Eigen::SparseMatrix<double> P,Eigen::SparseMatrix<double> D){
// Solves L * B = P * D for B.
// The sum of squares over B corresponds to the trace necessary for the edf/scale parameter computation (Wood & Fasiolo, 2017).
// This computation needs to be repeated for all smooth terms (i.e., all D) but can be completely parallelized. This will however
// only be worth it if the individual solves take time - big models.
Eigen::SparseMatrix<double> B = P * D;
L.triangularView<Eigen::Lower>().solveInPlace(B);
Eigen::SparseMatrix<double> solve_tr(Eigen::SparseMatrix<double> A,Eigen::SparseMatrix<double> B){
// Solves A*X=B for X, where A is lower triangular; the solution overwrites B and is returned.
// With B equal to the identity this yields X = inv(A). Importantly, when A is a n*n matrix, B can
// also be a n*m block of columns of the identity, so that inv(A) can be assembled in parallel,
// one column block at a time.

A.triangularView<Eigen::Lower>().solveInPlace(B);
return B;
}

Expand All @@ -209,5 +208,5 @@ PYBIND11_MODULE(cpp_solvers, m) {
m.def("solve_am", &solve_am, "Solve additive model, return coefficient vector and inverse");
m.def("solve_L", &solve_L, "Solve cholesky of XX+S");
m.def("solve_coef", &solve_coef, "Solve additive model coefficients");
m.def("solve_tr",&solve_tr,"Solve for trace matrix required for lambda update.");
m.def("solve_tr",&solve_tr,"Solve A*B = C, where A is lower triangular.");
}
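
The rewritten solve_tr now performs a plain lower-triangular solve against whatever right-hand side it is given, which is what allows the inverse of the Cholesky factor to be assembled from independent column blocks. Below is a minimal SciPy sketch, not part of this commit, illustrating that idea: solving against column blocks of the identity and stacking the results reproduces the full inverse (the block count of 3 and all variable names are chosen only for illustration).

# Minimal sketch (not part of this commit) of the idea behind the new solve_tr:
# solving the lower-triangular system against column blocks of the identity and
# stacking the resulting blocks reproduces the full inverse.
import numpy as np
import scipy as scp
from scipy.sparse.linalg import spsolve_triangular

n = 8
L = scp.sparse.tril(scp.sparse.random(n, n, density=0.3, random_state=0), format='csc')
L.setdiag(np.arange(1.0, n + 1.0))  # ensure a non-zero diagonal so the solve is well defined
L = L.tocsr()                       # spsolve_triangular expects CSR input

I = np.eye(n)
col_blocks = np.array_split(range(n), 3)   # three column blocks of the identity

# One triangular solve per block; in the commit these solves go to a multiprocessing pool.
inv_blocks = [spsolve_triangular(L, I[:, b], lower=True) for b in col_blocks]
L_inv = np.hstack(inv_blocks)

assert np.allclose(L.toarray() @ L_inv, np.eye(n))
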
80 changes: 72 additions & 8 deletions src/mssm/src/python/gamm_solvers.py
@@ -1,5 +1,7 @@
import numpy as np
import scipy as scp
import multiprocessing as mp
from itertools import repeat
import warnings
from .exp_fam import Family,Gaussian
from .penalties import PenType,id_dist_pen,translate_sparse
Expand All @@ -21,8 +23,8 @@ def cpp_solve_coef(y,X,S):
def cpp_solve_L(X,S):
return cpp_solvers.solve_L(X,S)

def cpp_solve_tr(L,P,D):
return cpp_solvers.solve_tr(L,P,D)
def cpp_solve_tr(A,B):
return cpp_solvers.solve_tr(A,B)

def step_fellner_schall_sparse(gInv,emb_SJ,Bps,cCoef,cLam,scale,verbose=False):
# Compute a generalized Fellner Schall update step for a lambda term. This update rule is
@@ -202,6 +204,63 @@ def apply_eigen_perm(Pr,InvCholXXSP):
InvCholXXS = InvCholXXSP @ Perm
return InvCholXXS

def compute_B_mp(L,PD):
B = cpp_solve_tr(L,PD)
return B.power(2).sum()

def compute_B(L,P,lTerm,n_c=10):
# Solves L @ B = P @ D for B, parallelizing over column
# blocks of D if int(D_len/1000) > 1, i.e., if the term
# spans at least 2000 columns of D

# D is extremely sparse and P only shuffles rows, so we
# can take only the columns which we know contain non-zero elements
D_start = lTerm.start_index
D_len = lTerm.rep_sj * lTerm.S_J.shape[1]
D_end = lTerm.start_index + D_len

D_r = int(D_len/1000)
if D_r > 1 and n_c > 1:
# Parallelize
n_c = min(D_r,n_c)
split = np.array_split(range(D_start,D_end),n_c)
PD = P @ lTerm.D_J_emb
PDs = [PD[:,split[i]] for i in range(n_c)]

with mp.Pool(processes=n_c) as pool:
args = zip(repeat(L),PDs)

pow_sums = pool.starmap(compute_B_mp,args)
return sum(pow_sums)

B = cpp_solve_tr(L,P @ lTerm.D_J_emb[:,D_start:D_end])
return B.power(2).sum()

def compute_Linv(L,n_c=10):
# Solves L @ inv(L) = I for inv(L), parallelizing over column
# blocks of I if int(I.shape[1]/10000) > 1

n_col = L.shape[1]
r = int(n_col/10000)
T = scp.sparse.eye(n_col,format='csc')
if r > 1 and n_c > 1:
# Parallelize over column blocks of I
# Can speed up computations considerably and is feasible memory-wise
# since L itself is super sparse.

n_c = min(r,n_c)
split = np.array_split(range(n_col),n_c)
LBs = [T[:,split[i]] for i in range(n_c)]

with mp.Pool(processes=n_c) as pool:
args = zip(repeat(L),LBs)

LBinvs = pool.starmap(cpp_solve_tr,args)

return scp.sparse.hstack(LBinvs)

return cpp_solve_tr(L,T)


def calculate_edf(InvCholXXS,penalties,colsX):
total_edf = colsX
Bs = []
@@ -286,13 +345,16 @@ def update_scale_edf(y,z,eta,Wr,rowsX,colsX,InvCholXXSP,Pr,family,penalties):

return wres,InvCholXXS,total_edf,term_edfs,Bs,scale

def update_coef_and_scale(y,yb,z,Wr,rowsX,colsX,X,Xb,family,S_emb,penalties):
def update_coef_and_scale(y,yb,z,Wr,rowsX,colsX,X,Xb,family,S_emb,penalties,n_c):
# Solves the additive model for a given set of weights and penalty
InvCholXXSP, Pr, coef, code = cpp_solve_am(yb,Xb,S_emb)
LP, Pr, coef, code = cpp_solve_coef(yb,Xb,S_emb)

if code != 0:
raise ArithmeticError(f"Solving for coef failed with code {code}. Model is likely unidentifiable.")

# Solve for inverse of Chol factor of XX+S
InvCholXXSP = compute_Linv(LP,n_c)

# Update mu & eta
eta = (X @ coef).reshape(-1,1)
mu = eta
Expand All @@ -306,10 +368,12 @@ def update_coef_and_scale(y,yb,z,Wr,rowsX,colsX,X,Xb,family,S_emb,penalties):

def solve_gamm_sparse(mu_init,y,X,penalties,col_S,family:Family,
maxiter=10,pinv="svd",conv_tol=1e-7,
extend_lambda=True,control_lambda=True,progress_bar=False):
extend_lambda=True,control_lambda=True,
progress_bar=False,n_c=10):
# Estimates a penalized Generalized additive mixed model, following the steps outlined in Wood (2017)
# "Generalized Additive Models for Gigadata"

n_c = min(mp.cpu_count(),n_c)
rowsX,colsX = X.shape
coef = None
n_coef = None
@@ -348,7 +412,7 @@ def solve_gamm_sparse(mu_init,y,X,penalties,col_S,family:Family,
total_edf,\
term_edfs,\
Bs,scale,wres = update_coef_and_scale(y,yb,z,Wr,rowsX,colsX,
X,Xb,family,S_emb,penalties)
X,Xb,family,S_emb,penalties,n_c)

# Deviance under these starting coefficients
# As well as penalized deviance
@@ -467,7 +531,7 @@ def solve_gamm_sparse(mu_init,y,X,penalties,col_S,family:Family,
total_edf,\
term_edfs,\
Bs,scale,wres = update_coef_and_scale(y,yb,z,Wr,rowsX,colsX,
X,Xb,family,S_emb,penalties)
X,Xb,family,S_emb,penalties,n_c)

# Compute gradient of REML with respect to lambda
# to check if step size needs to be reduced.
@@ -524,7 +588,7 @@ def solve_gamm_sparse(mu_init,y,X,penalties,col_S,family:Family,
total_edf,\
term_edfs,\
Bs,scale,wres = update_coef_and_scale(y,yb,z,Wr,rowsX,colsX,
X,Xb,family,S_emb,penalties)
X,Xb,family,S_emb,penalties,n_c)

# Final penalty
if len(penalties) > 0:
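
The parallel paths added in compute_B and compute_Linv pair one shared Cholesky factor with several column blocks via itertools.repeat and hand the triangular solves to a multiprocessing.Pool via starmap. The following stripped-down sketch shows that pattern in isolation; solve_block is a hypothetical stand-in for cpp_solve_tr, and the dense test matrix is only for illustration.

# Stripped-down sketch of the Pool/starmap pattern used in compute_B/compute_Linv.
# solve_block is a hypothetical stand-in for cpp_solve_tr.
import numpy as np
import multiprocessing as mp
from itertools import repeat
from scipy.linalg import solve_triangular

def solve_block(L, B):
    # Solve L @ X = B for one column block B of the identity.
    return solve_triangular(L, B, lower=True)

if __name__ == "__main__":
    n, n_c = 1000, 4
    # A well-conditioned lower-triangular factor.
    L = np.tril(np.random.rand(n, n)) + n * np.eye(n)

    I = np.eye(n)
    split = np.array_split(range(n), n_c)      # column blocks of the identity
    blocks = [I[:, s] for s in split]

    with mp.Pool(processes=n_c) as pool:
        # repeat(L) pairs the same factor with every block, exactly as in compute_Linv.
        inv_blocks = pool.starmap(solve_block, zip(repeat(L), blocks))

    L_inv = np.hstack(inv_blocks)
    assert np.allclose(L @ L_inv, np.eye(n))
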
14 changes: 7 additions & 7 deletions tutorials/1) GAMMs.ipynb
@@ -807,7 +807,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Converged!: 66%|██████▌ | 33/50 [00:01<00:00, 24.74it/s] "
"Converged!: 66%|██████▌ | 33/50 [00:01<00:00, 25.78it/s] "
]
},
{
@@ -935,7 +935,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Converged!: 68%|██████▊ | 34/50 [00:09<00:04, 3.40it/s] "
"Converged!: 68%|██████▊ | 34/50 [00:09<00:04, 3.77it/s] "
]
},
{
@@ -1105,7 +1105,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Converged!: 54%|█████▍ | 27/50 [00:01<00:01, 19.41it/s] "
"Converged!: 54%|█████▍ | 27/50 [00:01<00:01, 19.65it/s] "
]
},
{
@@ -1277,7 +1277,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Converged!: 28%|██▊ | 14/50 [00:05<00:13, 2.58it/s] "
"Converged!: 28%|██▊ | 14/50 [00:05<00:12, 2.79it/s] "
]
},
{
@@ -1426,7 +1426,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Converged!: 58%|█████▊ | 29/50 [00:11<00:08, 2.60it/s] "
"Converged!: 58%|█████▊ | 29/50 [00:10<00:07, 2.83it/s] "
]
},
{
@@ -1597,7 +1597,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Converged!: 48%|████▊ | 24/50 [00:27<00:29, 1.14s/it] "
"Converged!: 48%|████▊ | 24/50 [00:27<00:29, 1.13s/it] "
]
},
{
@@ -1816,7 +1816,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Converged!: 32%|███▏ | 16/50 [00:19<00:42, 1.24s/it] "
"Converged!: 32%|███▏ | 16/50 [00:19<00:41, 1.23s/it] "
]
},
{
