Skip to content

Commit

Permalink
Added fitting progress indication to GAMM
Browse files Browse the repository at this point in the history
  • Loading branch information
JoKra1 committed Jan 16, 2024
1 parent 75bdaf0 commit 0e69825
Show file tree
Hide file tree
Showing 4 changed files with 205 additions and 28 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ build-backend = "setuptools.build_meta"
[project]
dependencies=["numpy >= 1.24.1",
"pandas >= 1.5.3",
"scipy >= 1.10.0"]
"scipy >= 1.10.0",
"tqdm >= 4.66.1"]
name = "mssm"
authors = [
{ name="Joshua Krause", email="jokra001@proton.me" }
Expand Down
11 changes: 6 additions & 5 deletions src/mssm/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def print_smooth_terms(self,pen_cutoff=0.2):

##################################### Fitting #####################################

def fit(self,maxiter=30,conv_tol=1e-7,extend_lambda=True,control_lambda=True,restart=False):
def fit(self,maxiter=50,conv_tol=1e-7,extend_lambda=True,control_lambda=True,restart=False,progress_bar=True):
"""
Fit the specified model.
Expand Down Expand Up @@ -278,7 +278,8 @@ def fit(self,maxiter=30,conv_tol=1e-7,extend_lambda=True,control_lambda=True,res
coef,eta,wres,scale,LVI,edf,term_edf,penalty = solve_gamm_sparse(init_mu_flat,y_flat,
model_mat,penalties,self.formula.n_coef,
self.family,maxiter,"svd",
conv_tol,extend_lambda,control_lambda)
conv_tol,extend_lambda,control_lambda,
progress_bar)

self.__coef = coef
self.__scale = scale # ToDo: scale name is used in another context for more general mssm..
Expand Down Expand Up @@ -662,7 +663,7 @@ def fit(self,burn_in=100,maxiter_inner=30,m_avg=15,conv_tol=1e-7,extend_lambda=T
coef,eta,wres,scale,LVI,edf,term_edf,penalty = solve_gamm_sparse(init_mu_flat,state_y,
model_mat,penalties[j],self.formula.n_coef,
self.family,maxiter_inner,"svd",
conv_tol,extend_lambda,control_lambda)
conv_tol,extend_lambda,control_lambda,False)



Expand Down Expand Up @@ -975,7 +976,7 @@ def fit(self,maxiter_outer=100,maxiter_inner=30,conv_tol=1e-6,extend_lambda=True
coef,eta,wres,scale,LVI,edf,term_edf,penalty = solve_gamm_sparse(init_mu_flat,y_flat[NOT_NA_flat],
model_mat_full,penalties,self.formula.n_coef,
self.family,maxiter_inner,"svd",
conv_tol,extend_lambda,control_lambda)
conv_tol,extend_lambda,control_lambda,False)

# For state proposals we can utilize a temparature schedule. See sMsGamm.fit().
if schedule == "anneal":
Expand Down Expand Up @@ -1054,7 +1055,7 @@ def fit(self,maxiter_outer=100,maxiter_inner=30,conv_tol=1e-6,extend_lambda=True
coef,eta,wres,scale,LVI,edf,term_edf,penalty = solve_gamm_sparse(init_mu_flat,y_flat[NOT_NA_flat],
model_mat_full,penalties,self.formula.n_coef,
self.family,maxiter_inner,"svd",
conv_tol,extend_lambda,control_lambda)
conv_tol,extend_lambda,control_lambda,False)

# Next update all sojourn time distribution parameters

Expand Down
17 changes: 9 additions & 8 deletions src/mssm/src/python/gamm_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .exp_fam import Family,Gaussian
from .penalties import PenType,id_dist_pen,translate_sparse
import cpp_solvers
from tqdm import tqdm

def cpp_chol(A):
return cpp_solvers.chol(A)
Expand Down Expand Up @@ -305,7 +306,7 @@ def update_coef_and_scale(y,yb,z,Wr,rowsX,colsX,X,Xb,family,S_emb,penalties):

def solve_gamm_sparse(mu_init,y,X,penalties,col_S,family:Family,
maxiter=10,pinv="svd",conv_tol=1e-7,
extend_lambda=True,control_lambda=True):
extend_lambda=True,control_lambda=True,progress_bar=False):
# Estimates a penalized Generalized additive mixed model, following the steps outlined in Wood (2017)
# "Generalized Additive Models for Gigadata"

Expand Down Expand Up @@ -375,9 +376,11 @@ def solve_gamm_sparse(mu_init,y,X,penalties,col_S,family:Family,
lam_delta = np.array(lam_delta).reshape(-1,1)

# Loop to optimize smoothing parameter (see Wood, 2017)
converged = False
o_iter = 0
while o_iter < maxiter and not converged:
iterator = range(maxiter)
if progress_bar:
iterator = tqdm(iterator,desc="Fitting",leave=False)

for o_iter in iterator:

# We need the previous deviance and penalized deviance
# for step control and convergence control respectively
Expand Down Expand Up @@ -430,7 +433,8 @@ def solve_gamm_sparse(mu_init,y,X,penalties,col_S,family:Family,

# Test for convergence (Step 2 in Wood, 2017)
if abs(pen_dev - prev_pen_dev) < conv_tol*pen_dev:
converged = True
if progress_bar:
iterator.close()
break

# Update pseudo-dat weights for next coefficient step
Expand Down Expand Up @@ -511,9 +515,6 @@ def solve_gamm_sparse(mu_init,y,X,penalties,col_S,family:Family,
Bs,scale,wres = update_coef_and_scale(y,yb,z,Wr,rowsX,colsX,
X,Xb,family,S_emb,penalties)

# Update number of iterations completed
o_iter += 1

# Final penalty
if len(penalties) > 0:
penalty = coef.T @ S_emb @ coef
Expand Down
202 changes: 188 additions & 14 deletions tutorials/1) GAMMs.ipynb

Large diffs are not rendered by default.

0 comments on commit 0e69825

Please sign in to comment.