Skip to content

Commit

Permalink
Fixes to llk eval + better description for fitting routines
Browse files Browse the repository at this point in the history
  • Loading branch information
JoKra1 committed Jun 11, 2024
1 parent cf0c259 commit 4e3236f
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 24 deletions.
4 changes: 2 additions & 2 deletions src/mssm/src/python/exp_fam.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def lp(self,y,mu):

def llk(self,y,mu):
# y is observed proportion of success
return sum(self.lp(y,mu))
return sum(self.lp(y,mu))[0]

def deviance(self,y,mu):
D = 2 * (self.__max_llk - self.llk(y,mu))
Expand All @@ -124,7 +124,7 @@ def lp(self,y,mu,sigma=1):
return scp.stats.norm.logpdf(y,loc=mu,scale=math.sqrt(sigma))

def llk(self,y,mu,sigma = 1):
return sum(self.lp(y,mu,sigma))
return sum(self.lp(y,mu,sigma))[0]

def deviance(self,y,mu):
# Based on Faraway (2016)
Expand Down
65 changes: 43 additions & 22 deletions src/mssm/src/python/gamm_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,9 +620,15 @@ def correct_coef_step(coef,n_coef,dev,pen_dev,c_dev_prev,family,eta,mu,y,X,n_pen
# previous coefficient
n_coef = coef

# Step halving
n_coef = (coef + n_coef)/2

# Update mu & eta for correction
# Note, Wood (2017) show pseudo-data and weight computation in
# step 1 - which should be re-visited after the correction, but because
# mu and eta can change during the correction (due to step halving) and neither
# the pseudo-data nor the weights are necessary to compute the deviance it makes
# sense to only compute these once **after** we have completed the coef corrections.
if formula is None:
eta = (X @ n_coef).reshape(-1,1)
else:
Expand Down Expand Up @@ -797,12 +803,12 @@ def correct_lambda_step(y,yb,z,Wr,rowsX,colsX,X,Xb,
formula,form_Linv)

# Compute gradient of REML with respect to lambda
# to check if step size needs to be reduced.
# to check if step size needs to be reduced (part of step 6 in Wood, 2017).
lam_grad = [grad_lambda(lgdetDs[lti],Bs[lti],bsbs[lti],scale) for lti in range(len(penalties))]
lam_grad = np.array(lam_grad).reshape(-1,1)
check = lam_grad.T @ lam_delta

if check[0,0] < 0 and control_lambda: # because of minimization in Wood (2017) they use a different check.
if check[0,0] < 0 and control_lambda: # because of minimization in Wood (2017) they use a different check (step 7) but idea is the same.
# Reset extension or cut the step taken in half
for lti,lTerm in enumerate(penalties):
if extend_lambda:
Expand All @@ -821,7 +827,7 @@ def correct_lambda_step(y,yb,z,Wr,rowsX,colsX,X,Xb,
if extend_lambda and lam_checks == 0:
extend_by = adapt_extension_strategy(extend_by,False,dev_check,extension_method_lam)

# Accept the step and propose a new one as well!
# Accept the step and propose a new one as well! (part of step 6 in Wood, 2017; here uses efs from Wood & Fasiolo, 2017 to propose new lambda delta)
lam_accepted = True
lam_delta = []
for lti,(lGrad,lTerm) in enumerate(zip(lam_grad,penalties)):
Expand Down Expand Up @@ -855,8 +861,8 @@ def solve_gamm_sparse(mu_init,y,X,penalties,col_S,family:Family,
extend_lambda=True,control_lambda=True,
exclude_lambda=False,extension_method_lam = "nesterov",
form_Linv=True,progress_bar=False,n_c=10):
# Estimates a penalized Generalized additive mixed model, following the steps outlined in Wood (2017)
# "Generalized Additive Models for Gigadata"
# Estimates a penalized Generalized additive mixed model, following the steps outlined in Wood, Li, Shaddick, & Augustin (2017)
# "Generalized Additive Models for Gigadata" referred to as Wood (2017) below.

n_c = min(mp.cpu_count(),n_c)
rowsX,colsX = X.shape
Expand Down Expand Up @@ -895,16 +901,11 @@ def solve_gamm_sparse(mu_init,y,X,penalties,col_S,family:Family,
fit_info = Fit_info()
for o_iter in iterator:

# We need the previous deviance and penalized deviance
# for step control and convergence control respectively
prev_dev = dev
prev_pen_dev = pen_dev

if o_iter > 0:

# Obtain deviance and penalized deviance terms
# under current lambda for proposed coef (n_coef)
# and current coef.
# under **current** lambda for proposed coef (n_coef)
# and current coef. (see Step 3 in Wood, 2017)
dev = family.deviance(y,mu)
pen_dev = dev
c_dev_prev = prev_dev
Expand All @@ -916,7 +917,12 @@ def solve_gamm_sparse(mu_init,y,X,penalties,col_S,family:Family,
# Perform step-length control for the coefficients (Step 3 in Wood, 2017)
dev,pen_dev,mu,eta,coef = correct_coef_step(coef,n_coef,dev,pen_dev,c_dev_prev,family,eta,mu,y,X,len(penalties),S_emb,None,n_c)

# Test for convergence (Step 2 in Wood, 2017)
# Test for convergence (Step 2 in Wood, 2017), implemented based on step 4 in Wood, Goude, & Shaw (2016): Generalized
# additive models for large data-sets. They reccomend inspecting the change in deviance after a PQL iteration to monitor
# convergence. Wood (2017) check the REML gradient against a fraction of the current deviance so to determine whether the change
# in deviance is "small" enough, it is also compared to a fraction of the current deviance. mgcv's bam function also considers this
# for convergence decisions but as part of a larger composite criterion, also involving checks on the scale parameter for example.
# From simulations, I get the impression that the simple criterion proposed by WGS seems to suffice.
dev_diff = abs(pen_dev - prev_pen_dev)

if progress_bar:
Expand All @@ -929,18 +935,25 @@ def solve_gamm_sparse(mu_init,y,X,penalties,col_S,family:Family,
fit_info.code = 0
break

# Update pseudo-dat weights for next coefficient step
# Update pseudo-dat weights for next coefficient step (step 1 in Wood, 2017; but moved after the coef correction because z and Wr depend on
# mu and eta, which change during the correction but anything that needs to be computed during the correction (deviance) does not depend on
# z and Wr).
yb,Xb,z,Wr = update_PIRLS(y,yb,mu,eta,X,Xb,family)

# We need the deviance and penalized deviance of the model at this point (before completing steps 5-7 (dev_{old} in WGS used for convergence control)
# for coef step control (step 3 in Wood, 2017) and convergence control (step 2 in Wood, 2017 based on step 4 in Wood, Goude, & Shaw, 2016) respectively
prev_dev = dev
prev_pen_dev = pen_dev

# Step length control for proposed lambda change
# Step length control for proposed lambda change (steps 5-7 in Wood, 2017) adjusted to make use of EFS from Wood & Fasiolo, 2017
if len(penalties) > 0:

# Test the lambda update
# Test the lambda update (step 5 in Wood, 2017)
for lti,lTerm in enumerate(penalties):
lTerm.lam += lam_delta[lti][0]
#print(lTerm.lam,lam_delta[lti][0])

# Now check step length and compute lambda + coef update.
# Now check step length and compute lambda + coef update. (steps 6-7 in Wood, 2017)
dev_check = None
if o_iter > 0:
dev_check = dev_diff < 1e-3*pen_dev
Expand All @@ -967,6 +980,14 @@ def solve_gamm_sparse(mu_init,y,X,penalties,col_S,family:Family,
X,Xb,family,S_emb,None,None,
penalties,n_c,None,form_Linv)

# At this point we:
# - have corrected & accepted the lam_deltas added above (step 5)
# - have proposed new coefficients (n_coef)
# - have updated eta and mu to reflect these new coef
# - have assigned the deviance before completing steps 5-7 to prev_dev
# - have proposed new lambda deltas (lam_delta)
#
# This completes step 8 in Wood (2017)!
fit_info.iter += 1

# Final penalty
Expand Down Expand Up @@ -1357,11 +1378,6 @@ def solve_gamm_sparse2(formula:Formula,penalties,col_S,family:Family,
fit_info = Fit_info()
for o_iter in iterator:

# We need the previous deviance and penalized deviance
# for step control and convergence control respectively
prev_dev = dev
prev_pen_dev = pen_dev

if o_iter > 0:

# Obtain deviance and penalized deviance terms
Expand Down Expand Up @@ -1390,6 +1406,11 @@ def solve_gamm_sparse2(formula:Formula,penalties,col_S,family:Family,
iterator.close()
fit_info.code = 0
break

# We need the deviance and penalized deviance of the model at this point (before completing steps 5-7 (dev_{old} in WGS used for convergence control)
# for coef step control (step 3 in Wood, 2017) and convergence control (step 2 in Wood, 2017 based on step 4 in Wood, Goude, & Shaw, 2016) respectively
prev_dev = dev
prev_pen_dev = pen_dev

# Step length control for proposed lambda change
if len(penalties) > 0:
Expand Down

0 comments on commit 4e3236f

Please sign in to comment.