Support for whole-interval CI coverage and conditional posterior sampling
JoKra1 committed Jul 3, 2024
1 parent c5c21b4 commit b920c57
Showing 3 changed files with 184 additions and 3 deletions.
94 changes: 92 additions & 2 deletions src/mssm/models.py
@@ -8,6 +8,7 @@
from .src.python.gamm_solvers import solve_gamm_sparse,mp,repeat,tqdm,cpp_cholP,apply_eigen_perm,compute_Linv,solve_gamm_sparse2
from .src.python.terms import TermType,GammTerm,i,f,fs,irf,l,li,ri,rs
from .src.python.penalties import PenType
from .src.python.utils import sample_MVN

##################################### Base Class #####################################

@@ -329,12 +330,87 @@ def fit(self,maxiter=50,conv_tol=1e-7,extend_lambda=True,control_lambda=True,exc

##################################### Prediction #####################################

def predict(self, use_terms, n_dat,alpha=0.05,ci=False):
def sample_post(self,n_ps,use_post=None,deviations=False,seed=None):
"""
Obtain ``n_ps`` samples from the posterior [\boldsymbol{\beta} - \hat{\boldsymbol{\beta}}] | y,\lambda ~ N(0,V),
where V is [X.T@X + S_\lambda]^{-1}*scale (see Wood, 2017; section 6.10). To obtain samples for \boldsymbol{\beta} itself,
simply set ``deviations`` to False.
See ``sample_MVN`` in ``mssm.src.python.utils.py`` for more details.
References:
- Wood, S. N. (2017). Generalized Additive Models: An Introduction with R, Second Edition (2nd ed.).
Parameters:
:param n_ps: The number of posterior samples to obtain.
:param use_post: The indices corresponding to coefficients for which to actually obtain samples.
:param deviations: Whether to sample deviations [\boldsymbol{\beta} - \hat{\boldsymbol{\beta}}] (True) or \boldsymbol{\beta} directly (False).
:param seed: The seed to use for random sample generation.
"""
if deviations:
post = sample_MVN(n_ps,0,self.__scale,P=None,L=None,LI=self.lvi,use=use_post,seed=seed)
else:
post = sample_MVN(n_ps,self.__coef,self.__scale,P=None,L=None,LI=self.lvi,use=use_post,seed=seed)

return post
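For illustration, a minimal usage sketch (not part of this commit): the fitted ``model`` below is assumed to be a GAMM that has already been estimated via its ``fit()`` method.

# Hypothetical sketch, assuming `model` is an already fitted GAMM:
# draw 1000 samples of the coefficient vector from the conditional
# posterior \beta | y, \lambda.
post = model.sample_post(n_ps=1000, deviations=False, seed=20)
post_mean = post.mean(axis=1)  # close to the estimated coefficients (up to MC error)
post_sd = post.std(axis=1)     # per-coefficient posterior standard deviation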

def __adjust_CI(self,n_ps,b,predi_mat,use_terms,alpha,seed):
"""
Adjusts the point-wise CI to behave like a whole-interval CI (based on Wood, 2017; section 6.10.2 and Simpson, 2016):
self.coef +- b gives the point-wise interval, so for the interval to be a whole-interval one,
1-alpha% of posterior samples should fall completely within these boundaries.
From section 6.10 in Wood (2017) we have that *coef | y, lambda ~ N(coef,V), where V is self.lvi.T @ self.lvi * self.__scale.
The implication is that the deviations [*coef - coef] | y, lambda ~ N(0,V). In line with the definition above, 1-alpha% of
predi_mat@[*coef - coef] should fall within [-b,b]. Wood (2017) suggests finding a so that [-a*b,a*b] achieves this.
To do this, we compute a for every predi_mat@[*coef - coef] and then select the final one so that 1-alpha% of samples had an equal or lower
a. The consequence: 1-alpha% of samples drawn should fall completely within the modified boundaries.
References:
- Wood, S. N. (2017). Generalized Additive Models: An Introduction with R, Second Edition (2nd ed.).
- Simpson, G. (2016). Simultaneous intervals for smooths revisited.
"""
use_post = None
if not use_terms is None:
# If we have many random factor levels, but want to make predictions only
# for fixed effects, it's wasteful to sample all coefficients from posterior.
# The code below performs a selection of the coefficients to be sampled.
use_post = predi_mat.sum(axis=0) != 0
use_post = np.arange(0,predi_mat.shape[1])[use_post]

# Sample deviations [*coef - coef] from posterior of GAMM
post = self.sample_post(n_ps,use_post,deviations=True,seed=seed)

# To make computations easier, take the abs of predi_mat@[*coef - coef]. Because [-b,b] is symmetric, we can
# simply check whether abs(predi_mat@[*coef - coef]) < b by computing abs(predi_mat@[*coef - coef])/b. The max of
# this ratio, over rows of predi_mat, is a for this sample. If a<=1, no extension is necessary for this series.
if use_post is None:
fpost = np.abs(predi_mat@post)
else:
fpost = np.abs(predi_mat[:,use_post]@post)

# Compute the ratio abs(predi_mat@[*coef - coef])/b for every sample.
fpost = fpost / b[:,None]

# Then compute the max of this ratio, over rows of predi_mat, for every sample. np.max(fpost,axis=0) is now a vector
# with n_ps elements, holding for each sample the multiplicative adjustment a necessary to ensure that predi_mat@[*coef - coef]
# falls completely between [-a*b,a*b].
# The final multiplicative adjustment bmadq is selected from this vector to be high enough so that in 1-alpha of the simulations
# we have an equal or lower a.
bmadq = np.quantile(np.max(fpost,axis=0),1-alpha)

# Then adjust b
b *= bmadq

return b
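To make the adjustment above concrete, here is a self-contained toy sketch (not part of the commit) of the same computation on simulated quantities; ``predi_mat``, ``post``, and ``b`` are stand-ins for the model-based objects used in the method:

import numpy as np

rng = np.random.default_rng(0)
predi_mat = rng.standard_normal((50, 10))       # toy prediction matrix
post = rng.standard_normal((10, 10000))         # toy posterior deviation samples
b = 1.96 * np.sqrt((predi_mat**2).sum(axis=1))  # toy point-wise 95% half-widths

fpost = np.abs(predi_mat @ post) / b[:, None]   # per-row ratio to the point-wise bound
a = np.max(fpost, axis=0)                       # per-sample multiplicative adjustment
b_whole = b * np.quantile(a, 0.95)              # widened half-widths: ~95% of sampled
                                                # curves now lie entirely within +- b_whole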

def predict(self, use_terms, n_dat,alpha=0.05,ci=False,whole_interval=False,n_ps=10000,seed=None):
"""
Make a prediction using the fitted model for new data ``n_dat`` using only the terms indexed by ``use_terms``.
References:
- Wood, S. N. (2017). Generalized Additive Models: An Introduction with R, Second Edition (2nd ed.).
- Simpson, G. (2016). Simultaneous intervals for smooths revisited.
Parameters:
@@ -401,18 +477,26 @@ def predict(self, use_terms, n_dat,alpha=0.05,ci=False):
c = predi_mat @ self.lvi.T @ self.lvi * self.__scale @ predi_mat.T
c = c.diagonal()
b = scp.stats.norm.ppf(1-(alpha/2)) * np.sqrt(c)

# Whole-interval CI (section 6.10.2 in Wood, 2017); the same idea was also
# explored by Simpson (2016), who performs very similar computations to obtain
# such intervals. See the __adjust_CI function.
if whole_interval:
b = self.__adjust_CI(n_ps,b,predi_mat,use_terms,alpha,seed)

return pred,predi_mat,b

return pred,predi_mat,None
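A hypothetical call illustrating the new arguments (the fitted ``model`` and the prediction data frame ``new_dat`` are assumed, not part of this commit):

pred, predi_mat, b = model.predict(None, new_dat, alpha=0.05, ci=True,
                                   whole_interval=True, n_ps=10000, seed=20)
# pred - b and pred + b now delimit a CI with approximate whole-interval
# (simultaneous) coverage rather than point-wise coverage.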

def predict_diff(self,dat1,dat2,use_terms,alpha=0.05):
def predict_diff(self,dat1,dat2,use_terms,alpha=0.05,whole_interval=False,n_ps=10000,seed=None):
"""
Get the difference in the predictions for two datasets. Useful to compare a smooth estimated for
one level of a factor to the smooth estimated for another level of a factor. In that case, ``dat1`` and
``dat2`` should only differ in the level of said factor.
References:
- Wood, S. N. (2017). Generalized Additive Models: An Introduction with R, Second Edition (2nd ed.).
- Simpson, G. (2016). Simultaneous intervals for smooths revisited.
- ``get_difference`` function from ``itsadug`` R-package: https://rdrr.io/cran/itsadug/man/get_difference.html
Parameters:
Expand Down Expand Up @@ -447,6 +531,12 @@ def predict_diff(self,dat1,dat2,use_terms,alpha=0.05):
c = c.diagonal()
b = scp.stats.norm.ppf(1-(alpha/2)) * np.sqrt(c)

# Whole-interval CI (section 6.10.2 in Wood, 2017); the same idea was also
# explored by Simpson (2016), who performs very similar computations to obtain
# such intervals. See the __adjust_CI function.
if whole_interval:
b = self.__adjust_CI(n_ps,b,pmat_diff,use_terms,alpha,seed)

return diff,b
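And similarly for the difference between two predictions (``dat_A`` and ``dat_B`` are assumed data frames differing only in the level of one factor):

diff, b = model.predict_diff(dat_A, dat_B, use_terms=None, alpha=0.05,
                             whole_interval=True, n_ps=10000, seed=20)
# diff +- b gives a difference CI with approximate whole-interval coverage.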

class sMsGAMM(MSSM):
2 changes: 1 addition & 1 deletion src/mssm/src/python/formula.py
@@ -113,7 +113,7 @@ def reparam(X,S,cov,option=1,n_bins=30,QR=False,identity=False,scale=False):
# Now decompose X = Q @ R
if QR:
_,R = scp.linalg.qr(X.toarray(),mode='economic')
R = scp.sparse.csc_array(R)
R = scp.sparse.csr_array(R)

else:
XX = (X.T @ X).tocsc()
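As an aside, a small standalone check (not part of the commit) of the identity the QR branch of ``reparam`` presumably exploits: for the economic QR X = Q@R with orthonormal Q, the triangular factor satisfies X.T@X = R.T@R, so R can stand in for a decomposition of X.T@X:

import numpy as np
import scipy as scp

X = np.random.default_rng(1).standard_normal((100, 5))
_, R = scp.linalg.qr(X, mode='economic')  # economic QR: R is 5 x 5 upper triangular
R = scp.sparse.csr_array(R)               # stored sparse, as in reparam
assert np.allclose(X.T @ X, (R.T @ R).toarray())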
91 changes: 91 additions & 0 deletions src/mssm/src/python/utils.py
@@ -0,0 +1,91 @@
import numpy as np
import scipy as scp
import math
import warnings
from ..python.gamm_solvers import cpp_backsolve_tr

def sample_MVN(n,mu,scale,P,L,LI=None,use=None,seed=None):
"""
Draw ``n`` samples x from a multivariate normal with mean ``mu`` and covariance matrix Sigma so that Sigma/scale = LI.T@LI, LI = L^{-1}, and
finally L@L.T = {Sigma/scale}^{-1}. In other words, L*(1/scale)^{0.5} is the Cholesky factor of the precision matrix corresponding to Sigma.
Notably, L (and LI) have actually been computed for P@[X.T@X+S_\lambda]@P.T (see Wood \& Fasiolo, 2017), hence for sampling we need to correct
for the permutation matrix ``P``. If ``LI`` is provided, then ``P`` can be omitted and is assumed to have already been applied to ``LI``.
Used to sample the uncorrected posterior \beta|y,\lambda ~ N(\boldsymbol{\beta},(X.T@X+S_\lambda)^{-1}\phi) for a GAMM (see Wood, 2017).
Based on section 7.4 in Gentle (2009), assuming Sigma is p*p and the covariance matrix of the uncorrected posterior:
x = mu + P.T@LI.T*scale^{0.5}@z where z_i ~ N(0,1) for all i = 1,...,p
Notably, for the deviations x - mu we can rely on the equivalence:
L.T*(1/scale)^{0.5} @ P@(x - mu) = z
...so we can first solve for y in:
L.T*(1/scale)^{0.5} @ y = z
...where y = P@(x - mu), and then recover the deviations via:
x - mu = P.T@y
The latter allows us to avoid forming L^{-1} (which, unlike L, might not benefit from the sparsity-preserving permutation P). Hence, if ``LI is None``,
``L`` will be used for sampling.
Often we care only about a handful of elements in mu (usually the first ones, corresponding to the "fixed effects" in a GAMM). In that case we
can generate x only for this subset of interest by using only a row-block of L/LI (all columns remain). Argument ``use`` can be a Numpy array
containing the indices of elements in mu that should be sampled. Because this only works efficiently when ``LI`` is available, an error is raised
when ``not use is None and LI is None``.
If ``mu`` is set to any integer (i.e., not a Numpy array/list), it is treated as 0.
References:
- Wood, S. N. (2017). Generalized Additive Models: An Introduction with R, Second Edition (2nd ed.).
- Gentle, J. (2009). Computational Statistics.
"""
if L is None and LI is None:
raise ValueError("Either ``L`` or ``LI`` have to be provided.")

if not L is None and not LI is None:
warnings.warn("Both ``L`` and ``LI`` were provided, will rely on ``LI``.")

if not L is None and LI is None and P is None:
raise ValueError("When sampling with ``L`` ``P`` must be provided.")

if not use is None and LI is None:
raise ValueError("If ``use`` is not None ``LI`` must be provided.")

# Correct for scale
if not LI is None:
Cs = LI.T*math.sqrt(scale)
else:
Cs = L.T*math.sqrt(1/scale)

# Sample from N(0,1)
z = scp.stats.norm.rvs(size=Cs.shape[1]*n,random_state=seed).reshape(Cs.shape[1],n)

# Sample with L
if LI is None:
z = cpp_backsolve_tr(Cs.tocsc(),scp.sparse.csc_array(z)).toarray() # z now holds y = P@(x - mu)

if isinstance(mu,int):
return P.T@z

return mu[:,None] + P.T@z

else:
# Sample with LI
if not P is None:
Cs = P.T@Cs

if not use is None:
Cs = Cs[use,:]

if isinstance(mu,int):
return Cs@z

if not use is None:
mus = mu[use,None]
else:
mus = mu[:,None]

return mus + Cs@z
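A hypothetical standalone sketch of the LI-based branch (not part of the commit; assumes ``mssm`` is installed). Here LI is constructed directly from a toy covariance matrix rather than from a fitted model, so that Sigma/scale = LI.T@LI as required above:

import numpy as np
import scipy as scp
from mssm.src.python.utils import sample_MVN

rng = np.random.default_rng(42)
A = rng.standard_normal((4, 4))
Sigma = A @ A.T + np.eye(4)            # toy covariance matrix
scale = 2.0
C = np.linalg.cholesky(Sigma / scale)  # lower Cholesky: Sigma/scale = C@C.T
LI = scp.sparse.csr_array(C.T)         # so Sigma/scale = LI.T@LI
mu = np.arange(4, dtype=float)

post = sample_MVN(10000, mu, scale, P=None, L=None, LI=LI, seed=42)
print(post.shape)     # (4, 10000)
print(np.cov(post))   # close to Sigma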
