Re-worked lambda calculation to be more efficient using pre-computed rank
JoKra1 committed Jan 24, 2024
1 parent 0acecb3 commit 48c37aa
Showing 2 changed files with 74 additions and 34 deletions.
94 changes: 67 additions & 27 deletions src/mssm/src/python/gamm_solvers.py
@@ -26,12 +26,26 @@ def cpp_solve_L(X,S):
def cpp_solve_tr(A,B):
return cpp_solvers.solve_tr(A,B)

def step_fellner_schall_sparse(gInv,emb_SJ,Bps,cCoef,cLam,scale,verbose=False):
def compute_lgdetD_bsb(rank,cLam,gInv,emb_SJ,cCoef):
# Derivative of log(|S_lambda|+), the log of the "generalized determinant", with respect to lambda; see Wood, Shaddick, & Augustin (2017),
# Wood & Fasiolo (2016), Wood (2017), and Wood (2020).
if rank is not None:
# (gInv @ emb_SJ).trace() should be equal to rank(S_J)/cLam for single penalty terms (Wood, 2020)
lgdet_deriv = rank/cLam
else:
lgdet_deriv = (gInv @ emb_SJ).trace()

# The derivative of log(|X'X + S_lambda|) is computed elsewhere; here we need the remaining part from the log-likelihood (Wood & Fasiolo, 2016):
bSb = cCoef.T @ emb_SJ @ cCoef
return lgdet_deriv,bSb
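
Why the rank shortcut is valid: for a term with a single penalty, S_lambda = cLam * S_J, so pinv(S_lambda) = pinv(S_J)/cLam and tr(pinv(S_J) @ S_J) = rank(S_J), which reduces the trace to rank(S_J)/cLam (Wood, 2020). A minimal NumPy check of this identity on a toy penalty block (not part of the commit):

```python
import numpy as np

cLam = 4.0
S_J = np.diag([1.0, 1.0, 1.0, 0.0, 0.0])        # toy rank-3 penalty block
gInv = np.linalg.pinv(cLam * S_J)               # pseudo-inverse of S_lambda

trace_term = np.trace(gInv @ S_J)               # generic derivative term
shortcut = np.linalg.matrix_rank(S_J) / cLam    # pre-computed-rank shortcut
assert np.isclose(trace_term, shortcut)         # both equal 0.75
```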

def step_fellner_schall_sparse(lgdet_deriv,ldet_deriv,bSb,cLam,scale):
# Compute a generalized Fellner-Schall update step for a lambda term. This update rule is
# discussed in Wood & Fasiolo (2016) and is used here because of its efficiency.
# ToDo: (gInv @ emb_SJ).trace() should be equal to rank(S_J)/cLam for single penalty terms (Wood, 2020)
num = max(0,(gInv @ emb_SJ).trace() - Bps)
denom = max(0,cCoef.T @ emb_SJ @ cCoef)

num = max(0,lgdet_deriv - ldet_deriv)

denom = max(0,bSb)

# Especially when using null penalties, denom can realistically become
# equal to zero: every coefficient of a term is penalized away. In that
@@ -44,18 +58,15 @@ def step_fellner_schall_sparse(gInv,emb_SJ,Bps,cCoef,cLam,scale,verbose=False):
nLam = max(nLam,1e-7) # Prevent Lambda going to zero
nLam = min(nLam,1e+7) # Prevent overflow

if verbose:
print(f"Num = {(gInv @ emb_SJ).trace()} - {Bps} == {num}\nDenom = {denom}; Lambda = {nLam}")

# Compute lambda delta
delta_lam = nLam-cLam

return delta_lam
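
For reference, the elided lines above implement the generalized Fellner-Schall rule of Wood & Fasiolo (2016): the new lambda is scale * num/denom times the old one. A standalone sketch of that rule; the zero-denominator guard here is an assumption, the commit's actual null-penalty handling sits in the elided lines:

```python
def fellner_schall_update(lgdet_deriv, ldet_deriv, bSb, cLam, scale):
    # Sketch of the generalized Fellner-Schall step (Wood & Fasiolo, 2016).
    num = max(0, lgdet_deriv - ldet_deriv)
    denom = max(0, bSb)
    denom = max(denom, 1e-10)            # assumed guard: denom can hit zero for null penalties
    nLam = scale * (num / denom) * cLam  # multiplicative update of lambda
    nLam = max(nLam, 1e-7)               # prevent lambda going to zero
    nLam = min(nLam, 1e+7)               # prevent overflow
    return nLam - cLam                   # the delta the solver applies
```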

def grad_lambda(gInv,emb_SJ,Bps,cCoef,scale):
def grad_lambda(lgdet_deriv,ldet_deriv,bSb,scale):
# Partial derivative of the restricted likelihood with respect to lambda.
# From Wood & Fasiolo (2016)
return (gInv @ emb_SJ).trace()/2 - Bps/2 - (cCoef.T @ emb_SJ @ cCoef) / (2*scale)
return lgdet_deriv/2 - ldet_deriv/2 - bSb / (2*scale)
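
A toy evaluation of this gradient with hypothetical values, to make the three ingredients concrete:

```python
lgdet_deriv = 0.75   # rank/cLam shortcut from compute_lgdetD_bsb
ldet_deriv = 0.40    # d log|X'X + S_lambda| / d lambda, computed elsewhere
bSb, scale = 0.10, 1.0

grad = lgdet_deriv/2 - ldet_deriv/2 - bSb/(2*scale)   # 0.375 - 0.2 - 0.05 = 0.125
# grad > 0: the restricted likelihood still increases in lambda at this point.
```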

def compute_S_emb_pinv_det(col_S,penalties,pinv):
# Computes final S multiplied with lambda
@@ -114,13 +125,15 @@ def compute_S_emb_pinv_det(col_S,penalties,pinv):
S_pinv_cols = []
cIndexPinv = SJ_idx[0]

FS_use_rank = []
for SJi in range(SJ_idx_len):
# Now handle all pinv calculations because all penalties
# associated with a term have been collected in SJ

if SJ_terms[SJi] == 1 and SJ_types[SJi] == PenType.IDENTITY:
#print("Identity shortcut",SJ_lams[SJi])
SJ_pinv_elements,SJ_pinv_rows,SJ_pinv_cols,_,_,_ = id_dist_pen(SJs[SJi].shape[1],lambda x: 1/SJ_lams[SJi])

if SJ_terms[SJi] == 1:
cIndexPinv += (SJs[SJi].shape[1]*SJ_reps[SJi])
FS_use_rank.append(True)

else:
# Compute pinv(SJ) via cholesky factor L so that L @ L' = SJ' @ SJ.
# If SJ is full rank, then pinv(SJ) = inv(L)' @ inv(L) @ SJ'.
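
The identity behind this: if L @ L' = SJ' @ SJ and SJ has full column rank, then pinv(SJ) = (SJ' @ SJ)^-1 @ SJ' = inv(L)' @ inv(L) @ SJ'. A quick NumPy/SciPy check on a toy full-rank matrix (not part of the commit):

```python
import numpy as np
import scipy.linalg as scl

rng = np.random.default_rng(0)
SJ = rng.standard_normal((6, 4))          # toy full-column-rank penalty root
L = scl.cholesky(SJ.T @ SJ, lower=True)   # L @ L' = SJ' @ SJ
Linv = scl.solve_triangular(L, np.eye(4), lower=True)

assert np.allclose(Linv.T @ Linv @ SJ.T, np.linalg.pinv(SJ))
```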
@@ -153,18 +166,25 @@ def compute_S_emb_pinv_det(col_S,penalties,pinv):
SJ_pinv = scp.sparse.csc_array(SJ_pinv)

SJ_pinv_elements,SJ_pinv_rows,SJ_pinv_cols = translate_sparse(SJ_pinv)

SJ_pinv_rows = np.array(SJ_pinv_rows)
SJ_pinv_cols = np.array(SJ_pinv_cols)

for _ in range(SJ_reps[SJi]):
S_pinv_elements.extend(SJ_pinv_elements)
S_pinv_rows.extend(SJ_pinv_rows + cIndexPinv)
S_pinv_cols.extend(SJ_pinv_cols + cIndexPinv)
cIndexPinv += (SJ_pinv_cols[-1] + 1)
SJ_pinv_rows = np.array(SJ_pinv_rows)
SJ_pinv_cols = np.array(SJ_pinv_cols)

for _ in range(SJ_reps[SJi]):
S_pinv_elements.extend(SJ_pinv_elements)
S_pinv_rows.extend(SJ_pinv_rows + cIndexPinv)
S_pinv_cols.extend(SJ_pinv_cols + cIndexPinv)
cIndexPinv += (SJ_pinv_cols[-1] + 1)

for _ in range(SJ_terms[SJi]):
FS_use_rank.append(False)

S_pinv = scp.sparse.csc_array((S_pinv_elements,(S_pinv_rows,S_pinv_cols)),shape=(col_S,col_S))
return S_emb, S_pinv

if len(FS_use_rank) != len(penalties):
raise IndexError("An incorrect number of rank decisions were made.")

return S_emb, S_pinv, FS_use_rank

def PIRLS_pdat_weights(y,mu,eta,family:Family):
# Compute pseudo-data and weights for Penalized Reweighted Least Squares iteration (Wood, 2017, 6.1.1)
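
For intuition: with link g, the standard IRLS construction sets pseudo-data z = eta + (y - mu) * g'(mu) and weights w = 1/(g'(mu)**2 * V(mu)) (Wood, 2017, 6.1.1). A generic sketch for the Poisson/log-link case, not mssm's Family API:

```python
import numpy as np

def pirls_pseudodata_poisson(y, eta):
    # Poisson with log link: g'(mu) = 1/mu, V(mu) = mu (hypothetical helper).
    mu = np.exp(eta)
    z = eta + (y - mu) / mu   # pseudo-data: eta + (y - mu) * g'(mu)
    w = mu                    # weights: 1 / (g'(mu)**2 * V(mu)) collapses to mu
    return z, w
```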
@@ -394,7 +414,7 @@ def solve_gamm_sparse(mu_init,y,X,penalties,col_S,family:Family,

# Compute starting estimate S_emb and S_pinv
if len(penalties) > 0:
S_emb,S_pinv = compute_S_emb_pinv_det(col_S,penalties,pinv)
S_emb,S_pinv,FS_use_rank = compute_S_emb_pinv_det(col_S,penalties,pinv)
else:
S_emb = scp.sparse.csc_array((colsX, colsX), dtype=np.float64)

@@ -429,12 +449,19 @@ def solve_gamm_sparse(mu_init,y,X,penalties,col_S,family:Family,
if len(penalties) > 0:
lam_delta = []
for lti,lTerm in enumerate(penalties):
dLam = step_fellner_schall_sparse(S_pinv,lTerm.S_J_emb,Bs[lti],coef,lTerm.lam,scale)

lt_rank = None
if FS_use_rank[lti]:
lt_rank = lTerm.rank

lgdetD,bsb = compute_lgdetD_bsb(lt_rank,lTerm.lam,S_pinv,lTerm.S_J_emb,coef)
dLam = step_fellner_schall_sparse(lgdetD,Bs[lti],bsb,lTerm.lam,scale)

if extend_lambda:
extension = lTerm.lam + dLam*extend_by
if extension < 1e7 and extension > 1e-7: # Keep lambda in correct space
dLam *= extend_by

lam_delta.append(dLam)

lam_delta = np.array(lam_delta).reshape(-1,1)
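
The extend_lambda branch above accelerates the Fellner-Schall step whenever the extended lambda would stay inside the allowed (1e-7, 1e7) range; with hypothetical values:

```python
lam, dLam, extend_by = 10.0, 3.0, 2.0   # hypothetical current lambda, step, and factor
extension = lam + dLam * extend_by      # 16.0, inside (1e-7, 1e7)
if 1e-7 < extension < 1e7:
    dLam *= extend_by                   # accelerated step: dLam becomes 6.0
```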
@@ -523,7 +550,7 @@ def solve_gamm_sparse(mu_init,y,X,penalties,col_S,family:Family,
while not lam_accepted:

# Re-compute S_emb and S_pinv
S_emb,S_pinv = compute_S_emb_pinv_det(col_S,penalties,pinv)
S_emb,S_pinv,FS_use_rank = compute_S_emb_pinv_det(col_S,penalties,pinv)

# Update coefficients
eta,mu,n_coef,\
@@ -535,7 +562,19 @@ def solve_gamm_sparse(mu_init,y,X,penalties,col_S,family:Family,

# Compute gradient of REML with respect to lambda
# to check if step size needs to be reduced.
lam_grad = [grad_lambda(S_pinv,penalties[lti].S_J_emb,Bs[lti],n_coef,scale) for lti in range(len(penalties))]
lgdetDs = []
bsbs = []
for lti,lTerm in enumerate(penalties):

lt_rank = None
if FS_use_rank[lti]:
lt_rank = lTerm.rank

lgdetD,bsb = compute_lgdetD_bsb(lt_rank,lTerm.lam,S_pinv,lTerm.S_J_emb,n_coef)
lgdetDs.append(lgdetD)
bsbs.append(bsb)

lam_grad = [grad_lambda(lgdetDs[lti],Bs[lti],bsbs[lti],scale) for lti in range(len(penalties))]
lam_grad = np.array(lam_grad).reshape(-1,1)
check = lam_grad.T @ lam_delta
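
check is the directional derivative of the REML criterion along the proposed lambda step; the elided lines presumably shrink the step when it is negative. A toy illustration of such a control rule (the halving factor is an assumption, not taken from the commit):

```python
import numpy as np

lam_grad = np.array([[0.5], [-0.2]])    # hypothetical REML gradient
lam_delta = np.array([[0.1], [0.4]])    # proposed lambda step
check = float(lam_grad.T @ lam_delta)   # 0.05 - 0.08 = -0.03

if check < 0:            # step points downhill in REML
    lam_delta *= 0.5     # assumed control: halve the step and retry
```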

@@ -565,7 +604,8 @@ def solve_gamm_sparse(mu_init,y,X,penalties,col_S,family:Family,
for lti,(lGrad,lTerm) in enumerate(zip(lam_grad,penalties)):

if np.abs(lGrad[0]) >= 1e-8*np.sum(np.abs(lam_grad)):
dLam = step_fellner_schall_sparse(S_pinv,lTerm.S_J_emb,Bs[lti],n_coef,lTerm.lam,scale)

dLam = step_fellner_schall_sparse(lgdetDs[lti],Bs[lti],bsbs[lti],lTerm.lam,scale)

if extend_lambda:
extension = lTerm.lam + dLam*extend_by
14 changes: 7 additions & 7 deletions tutorials/1) GAMMs.ipynb
@@ -191,7 +191,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/joshmac/Documents/repos/mssm/src/mssm/src/python/formula.py:673: UserWarning: 3003 y values (9.32%) are NA.\n",
"/Users/joshmac/Documents/repos/mssm/src/mssm/src/python/formula.py:680: UserWarning: 3003 y values (9.32%) are NA.\n",
" warnings.warn(f\"{data.shape[0] - data[NAs_flat].shape[0]} {self.get_lhs().variable} values ({round((data.shape[0] - data[NAs_flat].shape[0]) / data.shape[0] * 100,ndigits=2)}%) are NA.\")\n"
]
}
@@ -421,7 +421,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/joshmac/Documents/repos/mssm/src/mssm/src/python/formula.py:673: UserWarning: 3003 y values (9.32%) are NA.\n",
"/Users/joshmac/Documents/repos/mssm/src/mssm/src/python/formula.py:680: UserWarning: 3003 y values (9.32%) are NA.\n",
" warnings.warn(f\"{data.shape[0] - data[NAs_flat].shape[0]} {self.get_lhs().variable} values ({round((data.shape[0] - data[NAs_flat].shape[0]) / data.shape[0] * 100,ndigits=2)}%) are NA.\")\n"
]
}
@@ -807,7 +807,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Converged!: 66%|██████ | 33/50 [00:01<00:00, 25.78it/s] "
"Converged!: 68%|██████ | 34/50 [00:01<00:00, 27.11it/s] "
]
},
{
@@ -935,7 +935,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Converged!: 68%|██████▊ | 34/50 [00:09<00:04, 3.77it/s] "
"Converged!: 68%|██████▊ | 34/50 [00:08<00:04, 3.84it/s] "
]
},
{
@@ -1105,7 +1105,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Converged!: 54%|█████▍ | 27/50 [00:01<00:01, 19.65it/s] "
"Converged!: 54%|█████▍ | 27/50 [00:01<00:01, 20.06it/s] "
]
},
{
@@ -1277,7 +1277,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Converged!: 28%|██▊ | 14/50 [00:05<00:12, 2.79it/s] "
"Converged!: 28%|██▊ | 14/50 [00:04<00:12, 2.80it/s] "
]
},
{
@@ -1426,7 +1426,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Converged!: 58%|█████▊ | 29/50 [00:10<00:07, 2.83it/s] "
"Converged!: 58%|█████▊ | 29/50 [00:09<00:07, 2.91it/s] "
]
},
{
