Skip to content

Commit

Permalink
Set up simple GAM tests
Browse files Browse the repository at this point in the history
  • Loading branch information
JoKra1 committed Mar 12, 2024
1 parent b142d9c commit 791bce6
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 6 deletions.
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ requires = [
]
build-backend = "setuptools.build_meta"

[tool.cibuildwheel]
test-requires = "pytest"
test-command = "pytest {project}/tests"

[project]
dependencies=["numpy >= 1.24.1",
"pandas >= 1.5.3",
Expand All @@ -28,4 +32,4 @@ dynamic = ["version"]

[tool.setuptools_scm]
# https://github.com/pypa/setuptools_scm/issues/342
local_scheme = "no-local-version"
local_scheme = "no-local-version"
2 changes: 1 addition & 1 deletion src/mssm/src/python/formula.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def map_csc_to_eigen(X):
see: https://github.com/fwilliams/numpyeigen/blob/master/src/npe_sparse_array.h#L74
"""

if X.getformat() != "csc":
if X.format != "csc":
raise TypeError(f"Format of sparse matrix passed to c++ MUST be 'csc' but is {X.getformat()}")

rows, cols = X.shape
Expand Down
73 changes: 73 additions & 0 deletions tests/test_gamm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from mssm.models import *
import numpy as np

class Test_GAM:

dat = pd.read_csv("./tutorials/data/GAMM/sim_dat.csv")

# mssm requires that the data-type for variables used as factors is 'O'=object
dat = dat.astype({'series': 'O',
'cond':'O',
'sub':'O',
'series':'O'})

formula = Formula(lhs=lhs("y"), # The dependent variable - here y!
terms=[i(), # The intercept, a
f(["time"])], # The f(time) term, by default parameterized with 9 basis functions (after absorbing one for identifiability)
data=dat,
print_warn=False)

model = GAMM(formula,Gaussian())

model.fit()

def test_GAMedf(self):
assert round(self.model.edf,ndigits=3) == 9.723

def test_GAMTermEdf(self):
assert round(self.model.term_edf[0],ndigits=3) == 8.723

def test_GAMsigma(self):
_, sigma = self.model.get_pars()
assert round(sigma,ndigits=3) == 1084.879

def test_GAMlam(self):
assert round(self.model.formula.penalties[0].lam,ndigits=5) == 0.0089


class Test_GAMM:

dat = pd.read_csv("./tutorials/data/GAMM/sim_dat.csv")

# mssm requires that the data-type for variables used as factors is 'O'=object
dat = dat.astype({'series': 'O',
'cond':'O',
'sub':'O',
'series':'O'})

formula = Formula(lhs=lhs("y"), # The dependent variable - here y!
terms=[i(), # The intercept, a
l(["cond"]), # For cond='b'
f(["time"],by="cond",constraint=ConstType.QR), # to-way interaction between time and cond; one smooth over time per cond level
f(["x"],by="cond",constraint=ConstType.QR), # to-way interaction between x and cond; one smooth over x per cond level
f(["time","x"],by="cond",constraint=ConstType.QR), # three-way interaction
fs(["time"],rf="sub")], # Random non-linear effect of time - one smooth per level of factor sub
data=dat,
print_warn=False)

model = GAMM(formula,Gaussian())

model.fit()

def test_GAMMedf(self):
assert round(self.model.edf,ndigits=3) == 153.601

def test_GAMMTermEdf(self):
assert np.array_equal(np.round(self.model.term_edf,decimals=3),np.array([6.892, 8.635, 1.181, 1.001, 1.001, 1.029, 131.861])) == True

def test_GAMMsigma(self):
_, sigma = self.model.get_pars()
assert round(sigma,ndigits=3) == 577.196

def test_GAMMlam(self):
assert np.array_equal(np.round([p.lam for p in self.model.formula.penalties],decimals=3),np.array([0.004, 0.006, 5842.507, 1101786.56 , 328846.811, 174267.629, 162215.095, 1178.787, 0.119, 2.166])) == True
37 changes: 33 additions & 4 deletions tutorials/1) GAMMs.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,26 @@
"y = model.formula.y_flat[model.formula.NOT_NA_flat] # The dependent variable after NAs were removed"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.0089"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"round(model.formula.penalties[0].lam,ndigits=5)"
]
},
{
"cell_type": "code",
"execution_count": 11,
Expand Down Expand Up @@ -1248,9 +1268,18 @@
},
{
"cell_type": "code",
"execution_count": 36,
"execution_count": 16,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/joshmac/Documents/repos/mssm/src/mssm/src/python/formula.py:1106: UserWarning: 3003 y values (9.32%) are NA.\n",
" warnings.warn(f\"{data.shape[0] - data[NAs_flat].shape[0]} {self.get_lhs().variable} values ({round((data.shape[0] - data[NAs_flat].shape[0]) / data.shape[0] * 100,ndigits=2)}%) are NA.\")\n"
]
}
],
"source": [
"formula7 = Formula(lhs=lhs(\"y\"), # The dependent variable - here y!\n",
" terms=[i(), # The intercept, a\n",
Expand All @@ -1264,7 +1293,7 @@
},
{
"cell_type": "code",
"execution_count": 37,
"execution_count": 17,
"metadata": {},
"outputs": [
{
Expand All @@ -1278,7 +1307,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Converged!: 38%|███▊ | 19/50 [00:07<00:12, 2.55it/s] "
"Converged!: 38%|███▊ | 19/50 [00:06<00:10, 2.87it/s] "
]
},
{
Expand Down

0 comments on commit 791bce6

Please sign in to comment.