Skip to content

Commit

Permalink
Support for penalty split for sub-factor of random smooth
Browse files Browse the repository at this point in the history
  • Loading branch information
JoKra1 committed Jul 9, 2024
1 parent 255c1ca commit 06a9bae
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 47 deletions.
19 changes: 12 additions & 7 deletions src/mssm/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ def fit(self,maxiter=50,conv_tol=1e-7,extend_lambda=True,control_lambda=True,exc
self.family,maxiter,"svd",
conv_tol,extend_lambda,control_lambda,
exclude_lambda,extension_method_lam,
self.formula.discretize is None,
len(self.formula.discretize) == 0,
progress_bar,n_cores)

else:
Expand All @@ -336,7 +336,7 @@ def fit(self,maxiter=50,conv_tol=1e-7,extend_lambda=True,control_lambda=True,exc
self.family,maxiter,"svd",
conv_tol,extend_lambda,control_lambda,
exclude_lambda,extension_method_lam,
self.formula.discretize is None,
len(self.formula.discretize) == 0,
progress_bar,n_cores)

self.__coef = coef
Expand Down Expand Up @@ -457,10 +457,15 @@ def predict(self, use_terms, n_dat,alpha=0.05,ci=False,whole_interval=False,n_ps
"""
var_map = self.formula.get_var_map()
var_keys = var_map.keys()
sub_group_vars = self.formula.get_subgroup_variables()

for k in var_keys:
if k not in n_dat.columns:
raise IndexError(f"Variable {k} is missing in new data.")
if k in sub_group_vars:
if k.split(":")[0] not in n_dat.columns:
raise IndexError(f"Variable {k.split(':')[0]} is missing in new data.")
else:
if k not in n_dat.columns:
raise IndexError(f"Variable {k} is missing in new data.")

# Encode test data
_,pred_cov_flat,_,_,_,_,_ = self.formula.encode_data(n_dat,prediction=True)
Expand Down Expand Up @@ -859,7 +864,7 @@ def fit(self,burn_in=100,maxiter_inner=30,m_avg=15,conv_tol=1e-7,extend_lambda=T
self.family,maxiter_inner,"svd",
conv_tol,extend_lambda,control_lambda,
exclude_lambda,"nesterov",
self.formula.discretize is None,
len(self.formula.discretize) == 0,
False,self.cpus)


Expand Down Expand Up @@ -1185,7 +1190,7 @@ def fit(self,maxiter_outer=100,maxiter_inner=30,conv_tol=1e-6,extend_lambda=True
self.family,maxiter_inner,"svd",
conv_tol,extend_lambda,control_lambda,
exclude_lambda,"nesterov",
self.formula.discretize is None,
len(self.formula.discretize) == 0,
False,self.cpus)

# For state proposals we can utilize a temparature schedule. See sMsGamm.fit().
Expand Down Expand Up @@ -1288,7 +1293,7 @@ def fit(self,maxiter_outer=100,maxiter_inner=30,conv_tol=1e-6,extend_lambda=True
self.family,maxiter_inner,"svd",
conv_tol,extend_lambda,control_lambda,
exclude_lambda,"nesterov",
self.formula.discretize is None,
len(self.formula.discretize) == 0,
False,self.cpus)

# Next update all sojourn time distribution parameters
Expand Down
Loading

0 comments on commit 06a9bae

Please sign in to comment.