From 4e3236fe6dfa85611d639e01c3dc33653c34bfa5 Mon Sep 17 00:00:00 2001 From: Joshua Krause <52180639+JoKra1@users.noreply.github.com> Date: Tue, 11 Jun 2024 10:36:14 +0200 Subject: [PATCH] Fixes to llk eval + better description for fitting routines --- src/mssm/src/python/exp_fam.py | 4 +- src/mssm/src/python/gamm_solvers.py | 65 +++++++++++++++++++---------- 2 files changed, 45 insertions(+), 24 deletions(-) diff --git a/src/mssm/src/python/exp_fam.py b/src/mssm/src/python/exp_fam.py index d8726f5..1f7e853 100644 --- a/src/mssm/src/python/exp_fam.py +++ b/src/mssm/src/python/exp_fam.py @@ -106,7 +106,7 @@ def lp(self,y,mu): def llk(self,y,mu): # y is observed proportion of success - return sum(self.lp(y,mu)) + return sum(self.lp(y,mu))[0] def deviance(self,y,mu): D = 2 * (self.__max_llk - self.llk(y,mu)) @@ -124,7 +124,7 @@ def lp(self,y,mu,sigma=1): return scp.stats.norm.logpdf(y,loc=mu,scale=math.sqrt(sigma)) def llk(self,y,mu,sigma = 1): - return sum(self.lp(y,mu,sigma)) + return sum(self.lp(y,mu,sigma))[0] def deviance(self,y,mu): # Based on Faraway (2016) diff --git a/src/mssm/src/python/gamm_solvers.py b/src/mssm/src/python/gamm_solvers.py index 3600951..329d7f0 100644 --- a/src/mssm/src/python/gamm_solvers.py +++ b/src/mssm/src/python/gamm_solvers.py @@ -620,9 +620,15 @@ def correct_coef_step(coef,n_coef,dev,pen_dev,c_dev_prev,family,eta,mu,y,X,n_pen # previous coefficient n_coef = coef + # Step halving n_coef = (coef + n_coef)/2 # Update mu & eta for correction + # Note, Wood (2017) show pseudo-data and weight computation in + # step 1 - which should be re-visited after the correction, but because + # mu and eta can change during the correction (due to step halving) and neither + # the pseudo-data nor the weights are necessary to compute the deviance it makes + # sense to only compute these once **after** we have completed the coef corrections. if formula is None: eta = (X @ n_coef).reshape(-1,1) else: @@ -797,12 +803,12 @@ def correct_lambda_step(y,yb,z,Wr,rowsX,colsX,X,Xb, formula,form_Linv) # Compute gradient of REML with respect to lambda - # to check if step size needs to be reduced. + # to check if step size needs to be reduced (part of step 6 in Wood, 2017). lam_grad = [grad_lambda(lgdetDs[lti],Bs[lti],bsbs[lti],scale) for lti in range(len(penalties))] lam_grad = np.array(lam_grad).reshape(-1,1) check = lam_grad.T @ lam_delta - if check[0,0] < 0 and control_lambda: # because of minimization in Wood (2017) they use a different check. + if check[0,0] < 0 and control_lambda: # because of minimization in Wood (2017) they use a different check (step 7) but idea is the same. # Reset extension or cut the step taken in half for lti,lTerm in enumerate(penalties): if extend_lambda: @@ -821,7 +827,7 @@ def correct_lambda_step(y,yb,z,Wr,rowsX,colsX,X,Xb, if extend_lambda and lam_checks == 0: extend_by = adapt_extension_strategy(extend_by,False,dev_check,extension_method_lam) - # Accept the step and propose a new one as well! + # Accept the step and propose a new one as well! (part of step 6 in Wood, 2017; here uses efs from Wood & Fasiolo, 2017 to propose new lambda delta) lam_accepted = True lam_delta = [] for lti,(lGrad,lTerm) in enumerate(zip(lam_grad,penalties)): @@ -855,8 +861,8 @@ def solve_gamm_sparse(mu_init,y,X,penalties,col_S,family:Family, extend_lambda=True,control_lambda=True, exclude_lambda=False,extension_method_lam = "nesterov", form_Linv=True,progress_bar=False,n_c=10): - # Estimates a penalized Generalized additive mixed model, following the steps outlined in Wood (2017) - # "Generalized Additive Models for Gigadata" + # Estimates a penalized Generalized additive mixed model, following the steps outlined in Wood, Li, Shaddick, & Augustin (2017) + # "Generalized Additive Models for Gigadata" referred to as Wood (2017) below. n_c = min(mp.cpu_count(),n_c) rowsX,colsX = X.shape @@ -895,16 +901,11 @@ def solve_gamm_sparse(mu_init,y,X,penalties,col_S,family:Family, fit_info = Fit_info() for o_iter in iterator: - # We need the previous deviance and penalized deviance - # for step control and convergence control respectively - prev_dev = dev - prev_pen_dev = pen_dev - if o_iter > 0: # Obtain deviance and penalized deviance terms - # under current lambda for proposed coef (n_coef) - # and current coef. + # under **current** lambda for proposed coef (n_coef) + # and current coef. (see Step 3 in Wood, 2017) dev = family.deviance(y,mu) pen_dev = dev c_dev_prev = prev_dev @@ -916,7 +917,12 @@ def solve_gamm_sparse(mu_init,y,X,penalties,col_S,family:Family, # Perform step-length control for the coefficients (Step 3 in Wood, 2017) dev,pen_dev,mu,eta,coef = correct_coef_step(coef,n_coef,dev,pen_dev,c_dev_prev,family,eta,mu,y,X,len(penalties),S_emb,None,n_c) - # Test for convergence (Step 2 in Wood, 2017) + # Test for convergence (Step 2 in Wood, 2017), implemented based on step 4 in Wood, Goude, & Shaw (2016): Generalized + # additive models for large data-sets. They reccomend inspecting the change in deviance after a PQL iteration to monitor + # convergence. Wood (2017) check the REML gradient against a fraction of the current deviance so to determine whether the change + # in deviance is "small" enough, it is also compared to a fraction of the current deviance. mgcv's bam function also considers this + # for convergence decisions but as part of a larger composite criterion, also involving checks on the scale parameter for example. + # From simulations, I get the impression that the simple criterion proposed by WGS seems to suffice. dev_diff = abs(pen_dev - prev_pen_dev) if progress_bar: @@ -929,18 +935,25 @@ def solve_gamm_sparse(mu_init,y,X,penalties,col_S,family:Family, fit_info.code = 0 break - # Update pseudo-dat weights for next coefficient step + # Update pseudo-dat weights for next coefficient step (step 1 in Wood, 2017; but moved after the coef correction because z and Wr depend on + # mu and eta, which change during the correction but anything that needs to be computed during the correction (deviance) does not depend on + # z and Wr). yb,Xb,z,Wr = update_PIRLS(y,yb,mu,eta,X,Xb,family) + + # We need the deviance and penalized deviance of the model at this point (before completing steps 5-7 (dev_{old} in WGS used for convergence control) + # for coef step control (step 3 in Wood, 2017) and convergence control (step 2 in Wood, 2017 based on step 4 in Wood, Goude, & Shaw, 2016) respectively + prev_dev = dev + prev_pen_dev = pen_dev - # Step length control for proposed lambda change + # Step length control for proposed lambda change (steps 5-7 in Wood, 2017) adjusted to make use of EFS from Wood & Fasiolo, 2017 if len(penalties) > 0: - # Test the lambda update + # Test the lambda update (step 5 in Wood, 2017) for lti,lTerm in enumerate(penalties): lTerm.lam += lam_delta[lti][0] #print(lTerm.lam,lam_delta[lti][0]) - # Now check step length and compute lambda + coef update. + # Now check step length and compute lambda + coef update. (steps 6-7 in Wood, 2017) dev_check = None if o_iter > 0: dev_check = dev_diff < 1e-3*pen_dev @@ -967,6 +980,14 @@ def solve_gamm_sparse(mu_init,y,X,penalties,col_S,family:Family, X,Xb,family,S_emb,None,None, penalties,n_c,None,form_Linv) + # At this point we: + # - have corrected & accepted the lam_deltas added above (step 5) + # - have proposed new coefficients (n_coef) + # - have updated eta and mu to reflect these new coef + # - have assigned the deviance before completing steps 5-7 to prev_dev + # - have proposed new lambda deltas (lam_delta) + # + # This completes step 8 in Wood (2017)! fit_info.iter += 1 # Final penalty @@ -1357,11 +1378,6 @@ def solve_gamm_sparse2(formula:Formula,penalties,col_S,family:Family, fit_info = Fit_info() for o_iter in iterator: - # We need the previous deviance and penalized deviance - # for step control and convergence control respectively - prev_dev = dev - prev_pen_dev = pen_dev - if o_iter > 0: # Obtain deviance and penalized deviance terms @@ -1390,6 +1406,11 @@ def solve_gamm_sparse2(formula:Formula,penalties,col_S,family:Family, iterator.close() fit_info.code = 0 break + + # We need the deviance and penalized deviance of the model at this point (before completing steps 5-7 (dev_{old} in WGS used for convergence control) + # for coef step control (step 3 in Wood, 2017) and convergence control (step 2 in Wood, 2017 based on step 4 in Wood, Goude, & Shaw, 2016) respectively + prev_dev = dev + prev_pen_dev = pen_dev # Step length control for proposed lambda change if len(penalties) > 0: